Merge pull request #18862 from sl-sergei:support_pool1d

Support for Pool1d layer for OpenCV and OpenCL targets

* Initial version of Pool1d support

* Fix variable naming

* Fix 1d pooling for OpenCL

* Change support logic, remove unnecessary variable, split the tests

* Remove other deprecated variables

* Fix warning. Check tests

* Change support check logic

* Change support check logic, 2
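
For reference, after this change a pooling layer can consume 3-dimensional N x C x L blobs directly. A minimal usage sketch, assuming a hypothetical ONNX file that contains a single 1D MaxPool with kernel 2 and stride 2 (any such exported graph would do):

#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    // "maxpooling_1d.onnx" is a placeholder name for a model with one 1D MaxPool node.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("maxpooling_1d.onnx");

    // 1D pooling operates on N x C x L blobs (batch, channels, length).
    const int shape[] = {1, 3, 8};
    cv::Mat input(3, shape, CV_32F);
    cv::randu(input, 0.f, 1.f);

    net.setInput(input);
    cv::Mat out = net.forward();                    // expected shape: 1 x 3 x 4
    std::cout << "pooled length: " << out.size[2] << std::endl;
    return 0;
}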
Sergei Slashchinin 2020-11-24 19:52:45 +03:00 committed by GitHub
parent 359ecda4fc
commit f4f462c50b
4 changed files with 205 additions and 77 deletions


@@ -248,8 +248,6 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         int type;
         std::vector<size_t> kernel_size, strides;
         std::vector<size_t> pads_begin, pads_end;
-        CV_DEPRECATED_EXTERNAL Size kernel, stride, pad;
-        CV_DEPRECATED_EXTERNAL int pad_l, pad_t, pad_r, pad_b;
         bool globalPooling; //!< Flag is true if at least one of the axes is global pooled.
         std::vector<bool> isGlobalPooling;
         bool computeMaxIdx;


@@ -85,8 +85,6 @@ public:
         computeMaxIdx = true;
         globalPooling = false;
         isGlobalPooling = std::vector<bool>(3, false);
-        stride = Size(1, 1);
-        pad_t = pad_l = pad_b = pad_r = 0;
         hasDynamicShapes = params.get<bool>("has_dynamic_shapes", false);
         shapesInitialized = !hasDynamicShapes;
@@ -108,16 +106,6 @@ public:
             getPoolingKernelParams(params, kernel_size, isGlobalPooling, pads_begin, pads_end, strides, padMode);
             globalPooling = isGlobalPooling[0] || isGlobalPooling[1] || isGlobalPooling[2];
-            if (kernel_size.size() == 2) {
-                kernel = Size(kernel_size[1], kernel_size[0]);
-                stride = Size(strides[1], strides[0]);
-                pad = Size(pads_begin[1], pads_begin[0]);
-                pad_t = pads_begin[0];
-                pad_l = pads_begin[1];
-                pad_b = pads_end[0];
-                pad_r = pads_end[1];
-            }
         }
         else if (params.has("pooled_w") || params.has("pooled_h"))
         {
@@ -165,17 +153,20 @@ public:
                 finalKernel.push_back(isGlobalPooling[idx] ? inp[i] : kernel_size[idx]);
             }
             kernel_size = finalKernel;
-            kernel = Size(kernel_size[1], kernel_size[0]);
         }
 
         getConvPoolPaddings(inp, kernel_size, strides, padMode, pads_begin, pads_end);
-        if (pads_begin.size() == 2) {
-            pad_t = pads_begin[0];
-            pad_l = pads_begin[1];
-            pad_b = pads_end[0];
-            pad_r = pads_end[1];
+
+        if (inputs[0].dims == 3)
+        {
+            //Pool1D
+            kernel_size.erase(kernel_size.begin() + 1);
+            strides.erase(strides.begin() + 1);
+            pads_begin.erase(pads_begin.begin() + 1);
+            pads_end.erase(pads_end.begin() + 1);
         }
+
 #ifdef HAVE_OPENCL
         poolOp.release();
 #endif
@@ -191,9 +182,11 @@ public:
             return false;
         if (kernel_size.size() == 3)
             return preferableTarget == DNN_TARGET_CPU;
+        if (kernel_size.size() == 1)
+            return false;
         if (preferableTarget == DNN_TARGET_MYRIAD) {
 #if INF_ENGINE_VER_MAJOR_LE(INF_ENGINE_RELEASE_2019R1)
-            if (type == MAX && (pad_l == 1 && pad_t == 1) && stride == Size(2, 2) ) {
+            if (type == MAX && (pads_begin[1] == 1 && pads_begin[0] == 1) && (strides[0] == 2 && strides[1] == 2)) {
                 return !isMyriadX();
             }
 #endif
@@ -205,19 +198,23 @@ public:
 #endif
         if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
         {
-            return !computeMaxIdx && type != STOCHASTIC;
+            return !computeMaxIdx && type != STOCHASTIC && kernel_size.size() > 1;
         }
-        else if (backendId == DNN_BACKEND_OPENCV || backendId == DNN_BACKEND_HALIDE)
+        else if (backendId == DNN_BACKEND_OPENCV)
         {
             if (kernel_size.size() == 3)
-                return (backendId == DNN_BACKEND_OPENCV && preferableTarget == DNN_TARGET_CPU);
-            if (kernel_size.empty() || kernel_size.size() == 2)
-                return backendId == DNN_BACKEND_OPENCV ||
-                       (backendId == DNN_BACKEND_HALIDE && haveHalide() &&
-                        (type == MAX || (type == AVE && !pad_t && !pad_l && !pad_b && !pad_r)));
+                return preferableTarget == DNN_TARGET_CPU;
+            if (kernel_size.size() <= 2)
+                return true;
             else
                 return false;
         }
+        else if (backendId == DNN_BACKEND_HALIDE)
+        {
+            if (kernel_size.empty() || kernel_size.size() == 2)
+                return haveHalide() &&
+                       (type == MAX || (type == AVE && !pads_begin[0] && !pads_begin[1] && !pads_end[0] && !pads_end[1]));
+        }
         return false;
     }
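
The net effect of these checks: 1D pooling is handled by the OpenCV backend (CPU and OpenCL targets) and is rejected for the Halide and nGraph backends. A short sketch of how a caller would steer a network onto the OpenCL path to exercise the new ocl4dnn code (it falls back to CPU when no OpenCL device is available):

#include <opencv2/dnn.hpp>

void preferOpenCL(cv::dnn::Net& net)
{
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_OPENCL);
}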
@@ -237,12 +234,25 @@ public:
             config.in_shape = shape(inputs[0]);
             config.out_shape = shape(outputs[0]);
-            config.kernel = kernel;
-            config.pad_l = pad_l;
-            config.pad_t = pad_t;
-            config.pad_r = pad_r;
-            config.pad_b = pad_b;
-            config.stride = stride;
+            if (inputs[0].dims == 3)
+            {
+                //Pool1D
+                config.kernel = Size(kernel_size[0], 1);
+                config.stride = Size(strides[0], 1);
+                config.pad_l = pads_begin[0];
+                config.pad_t = 0;
+                config.pad_r = pads_end[0];
+                config.pad_b = 0;
+            }
+            else
+            {
+                config.kernel = Size(kernel_size[1], kernel_size[0]);
+                config.stride = Size(strides[1], strides[0]);
+                config.pad_l = pads_begin[1];
+                config.pad_t = pads_begin[0];
+                config.pad_r = pads_end[1];
+                config.pad_b = pads_end[0];
+            }
             config.channels = inputs[0].size[1];
             config.pool_method = type == MAX ? LIBDNN_POOLING_METHOD_MAX :
                                 (type == AVE ? LIBDNN_POOLING_METHOD_AVE :
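
The OpenCL path keeps its 2-D pooling kernels: a 1-D pool is expressed as a width-only pool over an image of height 1, with the vertical padding forced to zero. A minimal sketch of that mapping, where Config is a stand-in for OCL4DNNPoolConfig:

#include <opencv2/core.hpp>
#include <vector>

struct Config { cv::Size kernel, stride; size_t pad_l, pad_t, pad_r, pad_b; };

// Sketch: map single-dimension pooling parameters onto a 2-D config of height 1.
Config make1DPoolConfig(const std::vector<size_t>& kernel_size,
                        const std::vector<size_t>& strides,
                        const std::vector<size_t>& pads_begin,
                        const std::vector<size_t>& pads_end)
{
    Config c;
    c.kernel = cv::Size((int)kernel_size[0], 1);   // cv::Size is (width, height)
    c.stride = cv::Size((int)strides[0], 1);
    c.pad_l = pads_begin[0];
    c.pad_r = pads_end[0];
    c.pad_t = c.pad_b = 0;
    return c;
}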
@@ -428,7 +438,6 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
 public:
     const Mat* src, *rois;
     Mat *dst, *mask;
-    Size kernel, stride;
     int pad_l, pad_t, pad_r, pad_b;
     bool avePoolPaddedArea;
     int nstripes;
@@ -453,7 +462,7 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         CV_Assert_N(
                     src.isContinuous(), dst.isContinuous(),
                     src.type() == CV_32F, src.type() == dst.type(),
-                    src.dims == 4 || src.dims == 5, dst.dims == 4 || dst.dims == 5,
+                    src.dims == 3 || src.dims == 4 || src.dims == 5, dst.dims == 3 || dst.dims == 4 || dst.dims == 5,
                     (((poolingType == ROI || poolingType == PSROI) &&
                       dst.size[0] == rois.size[0]) || src.size[0] == dst.size[0]),
                     poolingType == PSROI || src.size[1] == dst.size[1],
@@ -461,6 +470,9 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         PoolingInvoker p;
+        bool isPool1D = src.dims == 3;
+        bool isPool3D = src.dims == 5;
+
         p.src = &src;
         p.rois = &rois;
         p.dst = &dst;
@@ -471,12 +483,10 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         p.pads_end = pads_end;
         p.mask = &mask;
-        p.kernel = Size(kernel_size[1], kernel_size[0]);
-        p.stride = Size(strides[1], strides[0]);
         p.pad_l = pads_begin.back();
-        p.pad_t = pads_begin[pads_begin.size() - 2];
+        p.pad_t = isPool1D ? 0 : pads_begin[pads_begin.size() - 2];
         p.pad_r = pads_end.back();
-        p.pad_b = pads_end[pads_end.size() - 2];
+        p.pad_b = isPool1D ? 0 : pads_end[pads_end.size() - 2];
         p.avePoolPaddedArea = avePoolPaddedArea;
         p.nstripes = nstripes;
@@ -486,11 +496,11 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
             if( !computeMaxIdx )
             {
-                int height = src.size[src.dims - 2];
+                int height = isPool1D ? 1 : src.size[src.dims - 2];
                 int width = src.size[src.dims - 1];
-                int kernel_d = (kernel_size.size() == 3) ? kernel_size[0] : 1;
-                int kernel_h = kernel_size[kernel_size.size() - 2];
+                int kernel_d = isPool3D ? kernel_size[0] : 1;
+                int kernel_h = isPool1D ? 1 : kernel_size[kernel_size.size() - 2];
                 int kernel_w = kernel_size.back();
 
                 p.ofsbuf.resize(kernel_d * kernel_h * kernel_w);
@@ -510,13 +520,15 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
     {
         int channels = dst->size[1];
+        bool isPool3D = src->dims == 5;
         bool isPool2D = src->dims == 4;
-        int depth = !isPool2D? dst->size[2] : 1;
-        int height = dst->size[dst->dims - 2];
+        bool isPool1D = src->dims == 3;
+        int depth = isPool3D? dst->size[2] : 1;
+        int height = isPool1D? 1 : dst->size[dst->dims - 2];
         int width = dst->size[dst->dims - 1];
 
-        int inp_depth = !isPool2D? src->size[2] : 1;
-        int inp_height = src->size[src->dims - 2];
+        int inp_depth = isPool3D? src->size[2] : 1;
+        int inp_height = isPool1D? 1 : src->size[src->dims - 2];
         int inp_width = src->size[src->dims - 1];
 
         size_t total = dst->total();
@@ -524,12 +536,12 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         size_t stripeStart = r.start*stripeSize;
         size_t stripeEnd = std::min(r.end*stripeSize, total);
-        int kernel_d = !isPool2D? kernel_size[0] : 1;
-        int kernel_h = kernel_size[kernel_size.size() - 2];
+        int kernel_d = isPool3D? kernel_size[0] : 1;
+        int kernel_h = isPool1D? 1 : kernel_size[kernel_size.size() - 2];
         int kernel_w = kernel_size.back();
-        int stride_d = !isPool2D? strides[0] : 0;
-        int stride_h = strides[strides.size() - 2];
+        int stride_d = isPool3D? strides[0] : 0;
+        int stride_h = isPool1D? 1 : strides[strides.size() - 2];
         int stride_w = strides.back();
         bool compMaxIdx = computeMaxIdx;
@@ -720,7 +732,24 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
                     }
                 }
                 else
+#else
+                CV_UNUSED(isPool2D);
 #endif
+                if( isPool1D )
+                {
+                    const float* first = srcData + xstart;
+                    const float* last = srcData + xend;
+                    const float* max_elem = std::max_element(first, last);
+                    if (max_elem!=last)
+                    {
+                        dstData[x0] = *max_elem;
+                        if( compMaxIdx )
+                        {
+                            dstMaskData[x0] = std::distance(first, max_elem);
+                        }
+                    }
+                }
+                else
                 {
                     float max_val = -FLT_MAX;
                     if( compMaxIdx )
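
The 1-D max branch reduces to a std::max_element scan over the [xstart, xend) window. A self-contained sketch of the same idea on a plain array (values and sizes are illustrative, not the layer's internals):

#include <algorithm>
#include <cstdio>
#include <iterator>

int main()
{
    const float src[8] = {0.1f, 0.7f, 0.3f, 0.9f, 0.2f, 0.8f, 0.4f, 0.6f};
    const int kernel = 2, stride = 2, out_len = 4;

    for (int x0 = 0; x0 < out_len; ++x0)
    {
        const float* first = src + x0 * stride;              // xstart
        const float* last  = first + kernel;                 // xend (clipped in the real code)
        const float* max_elem = std::max_element(first, last);
        // std::distance(first, max_elem) is the offset of the max inside the window,
        // which is what the mask stores when compMaxIdx is set.
        std::printf("dst[%d] = %.1f (offset %td)\n", x0, *max_elem,
                    std::distance(first, max_elem));
    }
    return 0;
}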
@@ -794,6 +823,14 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
                 }
                 else
 #endif
+                if( isPool1D )
+                {
+                    const float* first = srcData + xstart;
+                    const float* last = srcData + xend;
+                    float sum_val = std::accumulate(first, last, 0.f);
+                    dstData[x0] = sum_val*inv_kernel_area;
+                }
+                else
                 {
                     float sum_val = 0.f;
                     for (int d = dstart; d < dend; ++d) {
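
Likewise, the 1-D average branch is a std::accumulate over the window scaled by the inverse kernel size. A standalone sketch with illustrative values:

#include <cstdio>
#include <numeric>

int main()
{
    const float src[6] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
    const int kernel = 3, stride = 3, out_len = 2;
    const float inv_kernel_area = 1.f / kernel;

    for (int x0 = 0; x0 < out_len; ++x0)
    {
        const float* first = src + x0 * stride;   // xstart
        const float* last  = first + kernel;      // xend
        float sum_val = std::accumulate(first, last, 0.f);
        std::printf("dst[%d] = %.2f\n", x0, sum_val * inv_kernel_area);   // 2.00, 5.00
    }
    return 0;
}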
@@ -907,20 +944,26 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
         const int inWidth = inputBuffer.width();
         const int inHeight = inputBuffer.height();
+        const size_t kernelHeight = kernel_size[0];
+        const size_t kernelWidth = kernel_size[1];
+        const size_t strideHeight = strides[0];
+        const size_t strideWidth = strides[1];
+        const size_t paddingTop = pads_begin[0];
+        const size_t paddingLeft = pads_begin[1];
 
         Halide::Var x("x"), y("y"), c("c"), n("n");
         Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
-        Halide::RDom r(0, kernel.width, 0, kernel.height);
+        Halide::RDom r(0, kernelWidth, 0, kernelHeight);
         Halide::Expr kx, ky;
-        if(pad_l || pad_t)
+        if(paddingLeft || paddingTop)
         {
-            kx = clamp(x * stride.width + r.x - pad_l, 0, inWidth - 1);
-            ky = clamp(y * stride.height + r.y - pad_t, 0, inHeight - 1);
+            kx = clamp(x * strideWidth + r.x - paddingLeft, 0, inWidth - 1);
+            ky = clamp(y * strideHeight + r.y - paddingTop, 0, inHeight - 1);
         }
         else
         {
-            kx = min(x * stride.width + r.x, inWidth - 1);
-            ky = min(y * stride.height + r.y, inHeight - 1);
+            kx = min(x * strideWidth + r.x, inWidth - 1);
+            ky = min(y * strideHeight + r.y, inHeight - 1);
         }
 
         // Halide::argmax returns tuple (r.x, r.y, max).
@@ -928,17 +971,17 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         // Compute offset from argmax in range [0, kernel_size).
         Halide::Expr max_index;
-        if(pad_l || pad_t)
+        if(paddingLeft || paddingTop)
         {
-            max_index = clamp(y * stride.height + res[1] - pad_t,
+            max_index = clamp(y * strideHeight + res[1] - paddingTop,
                               0, inHeight - 1) * inWidth +
-                        clamp(x * stride.width + res[0] - pad_l,
+                        clamp(x * strideWidth + res[0] - paddingLeft,
                               0, inWidth - 1);
         }
         else
         {
-            max_index = min(y * stride.height + res[1], inHeight - 1) * inWidth +
-                        min(x * stride.width + res[0], inWidth - 1);
+            max_index = min(y * strideHeight + res[1], inHeight - 1) * inWidth +
+                        min(x * strideWidth + res[0], inWidth - 1);
         }
         top(x, y, c, n) = { res[2], Halide::cast<float>(max_index) };
         return Ptr<BackendNode>(new HalideBackendNode(top));
@@ -952,21 +995,25 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         Halide::Buffer<float> inputBuffer = halideBuffer(inputs[0]);
         const int inW = inputBuffer.width(), inH = inputBuffer.height();
-        if ((inW - kernel.width) % stride.width || (inH - kernel.height) % stride.height)
+        const size_t kernelHeight = kernel_size[0];
+        const size_t kernelWidth = kernel_size[1];
+        const size_t strideHeight = strides[0];
+        const size_t strideWidth = strides[1];
+        if ((inW - kernelWidth) % strideWidth || (inH - kernelHeight) % strideHeight)
         {
             CV_Error(cv::Error::StsNotImplemented,
                      "Halide backend for average pooling with partial "
                      "kernels is not implemented");
         }
 
-        const float norm = 1.0f / (kernel.width * kernel.height);
+        const float norm = 1.0f / (kernelWidth * kernelHeight);
 
         Halide::Var x("x"), y("y"), c("c"), n("n");
         Halide::Func top = (name.empty() ? Halide::Func() : Halide::Func(name));
-        Halide::RDom r(0, kernel.width, 0, kernel.height);
+        Halide::RDom r(0, kernelWidth, 0, kernelHeight);
         top(x, y, c, n) = sum(
-            inputBuffer(x * stride.width + r.x,
-                        y * stride.height + r.y, c, n)) * norm;
+            inputBuffer(x * strideWidth + r.x,
+                        y * strideHeight + r.y, c, n)) * norm;
 
         return Ptr<BackendNode>(new HalideBackendNode(top));
 #endif  // HAVE_HALIDE
         return Ptr<BackendNode>();
@@ -1028,6 +1075,7 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
     {
         CV_Assert(inputs.size() != 0);
+        bool isPool1D = inputs[0].size() == 3;
         std::vector<int> inpShape(inputs[0].begin() + 2, inputs[0].end());
         std::vector<int> outShape(inputs[0].begin(), inputs[0].begin() + 2);
@@ -1056,14 +1104,15 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
         }
         else if (padMode.empty())
         {
-            for (int i = 0; i < local_kernel.size(); i++) {
+            int addedDims = isPool1D? inpShape.size() : local_kernel.size();
+            for (int i = 0; i < addedDims; i++) {
                 float dst = (float) (inpShape[i] + pads_begin[i] + pads_end[i] - local_kernel[i]) / strides[i];
                 outShape.push_back(1 + (ceilMode ? ceil(dst) : floor(dst)));
             }
 
             // If we have padding, ensure that the last pooling starts strictly
             // inside the image (instead of at the padding); otherwise clip the last.
-            for (int i = 0; i < pads_end.size(); i++) {
+            for (int i = 0; i < addedDims; i++) {
                 if (pads_end[i] && (outShape[2 + i] - 1) * strides[i] >= inpShape[i] + pads_end[i]) {
                     --outShape[2 + i];
                     CV_Assert((outShape[2 + i] - 1) * strides[i] < inpShape[i] + pads_end[i]);
@@ -1107,7 +1156,8 @@ virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inp
     {
         CV_UNUSED(inputs); // suppress unused variable warning
         long flops = 0;
-        size_t karea = std::accumulate(kernel_size.begin(), kernel_size.end(),
+        bool isPool1D = inputs[0].size() == 3;
+        size_t karea = std::accumulate(kernel_size.begin(), isPool1D? kernel_size.begin() + 1 : kernel_size.end(),
                                     1, std::multiplies<size_t>());
         for(int i = 0; i < outputs.size(); i++)
         {
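
With explicit padding the output length therefore follows out = 1 + floor((L + pad_begin + pad_end - kernel) / stride) (ceil when ceilMode is set), and the second loop trims the result if the last window would start entirely inside the end padding. A small worked check of that formula for one spatial dimension:

#include <cmath>
#include <cstdio>

int pool1dOutLen(int L, int kernel, int stride, int pad_begin, int pad_end, bool ceilMode)
{
    float dst = (float)(L + pad_begin + pad_end - kernel) / stride;
    int out = 1 + (int)(ceilMode ? std::ceil(dst) : std::floor(dst));
    if (pad_end && (out - 1) * stride >= L + pad_end)
        --out;                                   // same clipping as above
    return out;
}

int main()
{
    std::printf("%d %d\n",
                pool1dOutLen(8, 2, 2, 0, 0, false),    // 8 -> 4 with kernel 2, stride 2
                pool1dOutLen(7, 3, 2, 1, 1, false));   // 7 (+1/+1 pad) -> 4 with kernel 3, stride 2
    return 0;
}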


@@ -51,18 +51,20 @@ template<typename Dtype>
 OCL4DNNPool<Dtype>::OCL4DNNPool(OCL4DNNPoolConfig config)
 {
     int dims = config.in_shape.size();
-    int spatial_dims = 2;
+    int spatial_dims = config.in_shape.size()-2;
 
     channels_ = config.channels;
     pool_method_ = config.pool_method;
     avePoolPaddedArea = config.avePoolPaddedArea;
     computeMaxIdx = config.computeMaxIdx;
     use_half = config.use_half;
+    kernel_shape_.push_back(config.kernel.height);
+    kernel_shape_.push_back(config.kernel.width);
+    stride_.push_back(config.stride.height);
+    stride_.push_back(config.stride.width);
 
     for (int i = 0; i < spatial_dims; ++i)
     {
-        kernel_shape_.push_back(i == 0 ? config.kernel.height : config.kernel.width);
-        stride_.push_back(i == 0 ? config.stride.height : config.stride.width);
         im_in_shape_.push_back(config.in_shape[dims - spatial_dims + i]);
         im_out_shape_.push_back(config.out_shape[dims - spatial_dims + i]);
     }
@@ -75,10 +77,10 @@ OCL4DNNPool<Dtype>::OCL4DNNPool(OCL4DNNPoolConfig config)
     pad_l_ = config.pad_l;
     pad_r_ = config.pad_r;
     pad_b_ = config.pad_b;
-    height_ = im_in_shape_[0];
-    width_ = im_in_shape_[1];
-    pooled_height_ = im_out_shape_[0];
-    pooled_width_ = im_out_shape_[1];
+    height_ = spatial_dims == 1? 1 : im_in_shape_[0];
+    width_ = im_in_shape_.back();
+    pooled_height_ = spatial_dims == 1? 1 : im_out_shape_[0];
+    pooled_width_ = im_out_shape_.back();
     count_ = 1;
     for (int i = 0; i < config.out_shape.size(); ++i)


@@ -747,6 +747,84 @@ TEST_P(Test_ONNX_layers, DynamicAxes)
     testONNXModels("maxpooling_sigmoid_dynamic_axes");
 }
 
+TEST_P(Test_ONNX_layers, MaxPool1d)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
+    }
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    }
+    testONNXModels("maxpooling_1d");
+}
+
+TEST_P(Test_ONNX_layers, MaxPoolSigmoid1d)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
+    }
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    }
+    testONNXModels("maxpooling_sigmoid_1d");
+}
+
+TEST_P(Test_ONNX_layers, MaxPool1d_Twise)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
+    }
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    }
+    testONNXModels("two_maxpooling_1d");
+}
+
+TEST_P(Test_ONNX_layers, AvePool1d)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
+    }
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    }
+    testONNXModels("average_pooling_1d");
+}
+
+TEST_P(Test_ONNX_layers, PoolConv1d)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
+    }
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    }
+    testONNXModels("pool_conv_1d");
+}
+
+TEST_P(Test_ONNX_layers, ConvResizePool1d)
+{
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
+    }
+    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
+    {
+        if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
+    }
+    testONNXModels("conv_resize_pool_1d");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
 
 class Test_ONNX_nets : public Test_ONNX_layers