From ed69bcae2d171d9426cd3688a8b0ee14b8a140cd Mon Sep 17 00:00:00 2001 From: rogday Date: Tue, 19 Jul 2022 06:14:05 +0300 Subject: [PATCH] Merge pull request #21865 from rogday:nary_eltwise_layers Reimplementation of Element-wise layers with broadcasting support * init * semi-working initial version * add small_vector * wip * remove smallvec * add nary function * replace auto with Mat in lambda expr used in transform * uncomment asserts * autobuffer shape_buf & step_buf * fix a missing bracket * fixed a missing addLayer in parseElementWise * solve one-dimensional broadcast * remove pre_broadcast_transform for the case of two constants; fix missing constBlobsExtraInfo when addConstant is called * one autobuffer for step & shape * temporal fix for the missing original dimension information * fix parseUnsqueeze when it gets a 1d tensor constant * support sum/mean/min/max with only one input * reuse old code to handle cases of two non-constant inputs * add condition to handle div & mul of two non-constant inputs * use || instead of or * remove trainling spaces * enlarge buf in binary_forward to contain other buffer * use autobuffer in nary_forward * generate data randomly and add more cases for perf * add op and, or & xor * update perf_dnn * remove some comments * remove legacy; add two ONNX conformance tests in filter * move from cpu_denylist to all_denylist * adjust parsing for inputs>=2 Co-authored-by: fengyuentau --- .../dnn/include/opencv2/dnn/all_layers.hpp | 6 + modules/dnn/perf/perf_layer.cpp | 150 ++++ modules/dnn/src/init.cpp | 1 + .../dnn/src/layers/nary_eltwise_layers.cpp | 664 ++++++++++++++++++ modules/dnn/src/onnx/onnx_importer.cpp | 468 ++++-------- ...e_layer_filter_opencv_all_denylist.inl.hpp | 2 + ...e_layer_filter_opencv_cpu_denylist.inl.hpp | 1 + 7 files changed, 952 insertions(+), 340 deletions(-) create mode 100644 modules/dnn/src/layers/nary_eltwise_layers.cpp diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp 
b/modules/dnn/include/opencv2/dnn/all_layers.hpp index 5c86da2be4..6ecf85da7c 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -849,6 +849,12 @@ CV__DNN_INLINE_NS_BEGIN static Ptr create(const LayerParams ¶ms); }; + class CV_EXPORTS NaryEltwiseLayer : public Layer + { + public: + static Ptr create(const LayerParams ¶ms); + }; + class CV_EXPORTS BatchNormLayer : public ActivationLayer { public: diff --git a/modules/dnn/perf/perf_layer.cpp b/modules/dnn/perf/perf_layer.cpp index 06fa57f319..03ba8ab0e9 100644 --- a/modules/dnn/perf/perf_layer.cpp +++ b/modules/dnn/perf/perf_layer.cpp @@ -55,7 +55,156 @@ struct Layer_Slice : public TestBaseWithParam > } }; +struct Layer_NaryEltwise : public TestBaseWithParam > +{ + void test_layer(const std::vector& a_shape, const std::vector& b_shape, const String op, bool isRef = false) + { + int backendId = get<0>(GetParam()); + int targetId = get<1>(GetParam()); + Mat a(a_shape, CV_32FC1); + Mat b(b_shape, CV_32FC1); + + Scalar mean = 0.f; + Scalar std = 1.f; + randn(a, mean, std); + randn(b, mean, std); + + + Net net; + LayerParams lp; + if (isRef) + lp.type = "Eltwise"; + else + lp.type = "NaryEltwise"; + lp.name = "testLayer"; + lp.set("operation", op); + int id = net.addLayerToPrev(lp.name, lp.type, lp); + net.connect(0, 1, id, 1); + + // warmup + { + std::vector inpNames(2); + inpNames[0] = "a"; + inpNames[1] = "b"; + net.setInputsNames(inpNames); + net.setInput(a, inpNames[0]); + net.setInput(b, inpNames[1]); + + net.setPreferableBackend(backendId); + net.setPreferableTarget(targetId); + Mat out = net.forward(); + } + + TEST_CYCLE() + { + Mat res = net.forward(); + } + + SANITY_CHECK_NOTHING(); + } + + int N = 8; + int C = 256; + int H = 128; + int W = 100; +}; + + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_add) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "add"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_div) +{ + test_layer({N, C, H, W}, {N, C, H, W}, 
"div"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_ref_div) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "div", true); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_equal) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "equal"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_greater) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "greater"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_less) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "less"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_max) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "max"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_ref_max) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "max", true); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_mean) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "mean"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_min) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "min"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_ref_min) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "min", true); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_mul) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "mul"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_ref_mul) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "prod", true); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_pow) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "pow"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_sub) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "sub"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_sum) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "sum"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_NCHW_ref_sum) +{ + test_layer({N, C, H, W}, {N, C, H, W}, "sum", true); +} + +PERF_TEST_P_(Layer_NaryEltwise, NCHW_C_sum) +{ + test_layer({N, C, H, W}, {C, 1, 1}, "sum"); +} + +PERF_TEST_P_(Layer_NaryEltwise, NHWC_C) +{ + test_layer({N, H, W, C}, {1, C}, "sum"); +} PERF_TEST_P_(Layer_Slice, YOLOv4_tiny_1) { @@ -91,5 +240,6 @@ PERF_TEST_P_(Layer_Slice, FastNeuralStyle_eccv16) } INSTANTIATE_TEST_CASE_P(/**/, 
Layer_Slice, dnnBackendsAndTargets(false, false)); +INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); } // namespace diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp index 6979d1864d..e3ce6de40d 100644 --- a/modules/dnn/src/init.cpp +++ b/modules/dnn/src/init.cpp @@ -150,6 +150,7 @@ void initializeLayerFactory() CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer); CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer); + CV_DNN_REGISTER_LAYER_CLASS(NaryEltwise, NaryEltwiseLayer); CV_DNN_REGISTER_LAYER_CLASS(Permute, PermuteLayer); CV_DNN_REGISTER_LAYER_CLASS(ShuffleChannel, ShuffleChannelLayer); CV_DNN_REGISTER_LAYER_CLASS(PriorBox, PriorBoxLayer); diff --git a/modules/dnn/src/layers/nary_eltwise_layers.cpp b/modules/dnn/src/layers/nary_eltwise_layers.cpp new file mode 100644 index 0000000000..db37f16060 --- /dev/null +++ b/modules/dnn/src/layers/nary_eltwise_layers.cpp @@ -0,0 +1,664 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
+
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+#include <opencv2/dnn/shape_utils.hpp>
+
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+
+namespace cv
+{
+namespace dnn
+{
+
+// Element-wise layer with broadcasting over an arbitrary number of inputs.
+//
+// * Strictly binary operations (comparisons, pow, bit operations, mod, mul, sub, add, div)
+//   consume inputs[0] and inputs[1].
+// * Reductions (max, min, sum, mean) fold all inputs from left to right.
+//
+// Only the OpenCV CPU backend is supported and the output is written into outputs[0].
+class NaryEltwiseLayerImpl CV_FINAL : public NaryEltwiseLayer
+{
+public:
+    enum class OPERATION
+    {
+        AND = 0,
+        EQUAL,
+        GREATER,
+        GREATER_EQUAL,
+        LESS,
+        LESS_EQUAL,
+        OR,
+        POW,
+        XOR,
+        BITSHIFT,
+        MAX,
+        MEAN,
+        MIN,
+        MOD,
+        PROD,
+        SUB,
+        SUM,
+        ADD,
+        DIV,
+    } op;
+
+    // Reads the (case-insensitive) "operation" parameter, defaulting to "sum".
+    // Raises StsBadArg for an unknown operation name.
+    NaryEltwiseLayerImpl(const LayerParams& params)
+    {
+        setParamsFrom(params);
+
+        String operation = toLowerCase(params.get<String>("operation", "sum"));
+
+        if (operation == "equal")
+            op = OPERATION::EQUAL;
+        else if (operation == "greater")
+            op = OPERATION::GREATER;
+        else if (operation == "greater_equal")
+            op = OPERATION::GREATER_EQUAL;
+        else if (operation == "less")
+            op = OPERATION::LESS;
+        else if (operation == "less_equal")
+            op = OPERATION::LESS_EQUAL;
+        else if (operation == "pow")
+            op = OPERATION::POW;
+        else if (operation == "bitshift")
+            op = OPERATION::BITSHIFT;
+        else if (operation == "max")
+            op = OPERATION::MAX;
+        else if (operation == "mean")
+            op = OPERATION::MEAN;
+        else if (operation == "min")
+            op = OPERATION::MIN;
+        else if (operation == "mod")
+            op = OPERATION::MOD;
+        else if (operation == "mul")
+            op = OPERATION::PROD;
+        else if (operation == "sub")
+            op = OPERATION::SUB;
+        else if (operation == "sum")
+            op = OPERATION::SUM;
+        else if (operation == "add")
+            op = OPERATION::ADD;
+        else if (operation == "div")
+            op = OPERATION::DIV;
+        else if (operation == "and")
+            op = OPERATION::AND;
+        else if (operation == "or")
+            op = OPERATION::OR;
+        else if (operation == "xor")
+            op = OPERATION::XOR;
+        else
+            CV_Error(cv::Error::StsBadArg, "Unknown operation type \"" + operation + "\"");
+    }
+
+    virtual bool supportBackend(int backendId) CV_OVERRIDE
+    {
+        return backendId == DNN_BACKEND_OPENCV;
+    }
+
+    // Computes the broadcast result shape of `shapes`.
+    // Shorter shapes are left-padded with 1s; per axis, all extents must be equal or 1.
+    // `shapes` is taken by value on purpose: the padding is applied to the local copy.
+    static MatShape findCommonShape(std::vector<MatShape> shapes)
+    {
+        CV_Assert(!shapes.empty());
+        const size_t dim = std::max_element(shapes.begin(), shapes.end(),
+                            [](const MatShape& a, const MatShape& b)
+                            { return a.size() < b.size(); })->size();
+
+        for (auto& shape : shapes)
+        {
+            shape.insert(shape.begin(), dim - shape.size(), 1);
+        }
+
+        MatShape outShape(dim, 1);
+        for (size_t i = 0; i < dim; ++i)
+        {
+            for (const auto& shape : shapes)
+            {
+                if (shape[i] != outShape[i])
+                {
+                    CV_Assert(shape[i] == 1 || outShape[i] == 1);
+                    outShape[i] = std::max(outShape[i], shape[i]);
+                }
+            }
+        }
+
+        return outShape;
+    }
+
+    // Normalizes `narrays` arrays (index 0 is the output) for a broadcast loop.
+    //
+    // narrays, max_ndims - number of arrays and the common rank to expand to
+    // elemsize, ndims    - per-array element size in bytes and original rank
+    // shape_, step_      - per-array original shapes and byte steps (steps may be null)
+    // shape, step        - per-array output buffers of max_ndims entries each
+    //
+    // On success every array has max_ndims dims, mergeable neighbouring dims are
+    // flattened into the trailing ones, and the step of every size-1 dim is 0,
+    // so broadcasting needs no special casing in the loops.
+    // Returns false if any array has a zero-sized dimension (nothing to compute).
+    static bool prepare_for_broadcast_op(
+        int narrays, int max_ndims, const size_t* elemsize,
+        const int* ndims, const int** shape_, const size_t** step_,
+        int** shape, size_t** step)
+    {
+        int i, j, k;
+
+        // step 1.
+        // * make all inputs and the output max_ndims-dimensional.
+        // ** prepend dimension 1 to the mat of less dims
+        // * compute proper step's
+        for (i = max_ndims-1; i >= 0; i-- ) {
+            for (k = 0; k < narrays; k++) {
+                j = ndims[k] - (max_ndims - i);
+                int sz_i = j >= 0 ? shape_[k][j] : 1;
+                size_t st_i = j >= 0 && step_ && step_[k] && step_[k][j] > 0 ? step_[k][j] :
+                    i == max_ndims-1 ? elemsize[k] : step[k][i+1]*shape[k][i+1];
+                CV_DbgAssert(st_i % elemsize[k] == 0);
+                shape[k][i] = sz_i;
+                step[k][i] = st_i;
+                if (shape[k][i] == 0)
+                    return false;
+            }
+        }
+
+        // step 3. Let's do the flattening first,
+        // since we'd need proper values of steps to check continuity.
+        // this loop is probably the most tricky part
+        // in the whole implementation of broadcasting.
+        j = max_ndims-1;
+        for (i = j - 1; i >= 0; i--) {
+            bool all_contiguous = true, all_scalars = true, all_consistent = true;
+            for(k = 0; k < narrays; k++) {
+                size_t st = step[k][j]*shape[k][j];
+                bool prev_scalar = shape[k][j] == 1;
+                bool scalar = shape[k][i] == 1;
+                all_contiguous = all_contiguous && (st == step[k][i]);
+                all_scalars = all_scalars && scalar;
+                all_consistent = all_consistent && (scalar == prev_scalar);
+            }
+            if (all_contiguous && (all_consistent || all_scalars)) {
+                // dim i can be merged into dim j for every array
+                for(k = 0; k < narrays; k++)
+                    shape[k][j] *= shape[k][i];
+            } else {
+                // start a new flattened dim and move dim i into it
+                j--;
+                if (i < j) {
+                    for(k = 0; k < narrays; k++) {
+                        shape[k][j] = shape[k][i];
+                        step[k][j] = step[k][i];
+                    }
+                }
+            }
+        }
+
+        // step 2. Set some step's to 0's.
+        for (i = max_ndims-1; i >= j; i--) {
+            for (k = 0; k < narrays; k++)
+                step[k][i] = shape[k][i] == 1 ? 0 : step[k][i];
+        }
+        // leading dims freed by the flattening become dummy 1-sized dims
+        for (; i >= 0; i--) {
+            for (k = 0; k < narrays; k++) {
+                step[k][i] = 0;
+                shape[k][i] = 1;
+            }
+        }
+        return true;
+    }
+
+    // The single output has the broadcast shape of all inputs.
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        MatShape outShape = findCommonShape(inputs);
+        outputs.assign(1, outShape);
+        return false;
+    }
+
+    // Inner loops of a binary operation on arrays already normalized by
+    // prepare_for_broadcast_op(): `shape` is the (flattened) output shape and
+    // step1/step2/step are byte steps in which broadcast dims have step 0.
+    template <typename T, typename Functor>
+    void binary_forward_impl(
+            int ndims, const int* shape,
+            const char* data1, const size_t* step1,
+            const char* data2, const size_t* step2,
+            char* data, const size_t* step,
+            const Functor& op)
+    {
+        CV_Assert(ndims >= 2);
+        size_t dp1 = step1[ndims-1]/sizeof(T);
+        size_t dp2 = step2[ndims-1]/sizeof(T);
+        size_t dp = step[ndims-1]/sizeof(T);
+        int k, n1 = shape[ndims-1], n2 = shape[ndims-2];
+        size_t plane_idx, nplanes = 1;
+        for (k = 0; k < ndims-2; k++) nplanes *= shape[k];
+
+        for (plane_idx = 0; plane_idx < nplanes; plane_idx++) {
+            const char* ptr1_ = data1;
+            const char* ptr2_ = data2;
+            char* ptr_ = data;
+            // decompose the linear plane index into per-dim indices
+            size_t idx = plane_idx;
+            for (k = ndims-3; k >= 0; k--) {
+                size_t next_idx = idx/shape[k];
+                int i_k = (int)(idx - next_idx*shape[k]);
+                ptr1_ += i_k*step1[k];
+                ptr2_ += i_k*step2[k];
+                ptr_ += i_k*step[k];
+                idx = next_idx;
+            }
+            for (int i2 = 0; i2 < n2; i2++, ptr1_ += step1[ndims-2],
+                                            ptr2_ += step2[ndims-2],
+                                            ptr_ += step[ndims-2])
+            {
+                const T* ptr1 = (const T*)ptr1_;
+                const T* ptr2 = (const T*)ptr2_;
+                T* ptr = (T*)ptr_;
+                if (dp1 == 1 && dp2 == 1 && dp == 1) {
+                    // both rows are dense
+                    for(int i1 = 0; i1 < n1; i1++)
+                        ptr[i1] = op(ptr1[i1], ptr2[i1]);
+                } else if (dp1 == 1 && dp2 == 0 && dp == 1){
+                    // second operand is broadcast along the row
+                    T x2 = *ptr2;
+                    for(int i1 = 0; i1 < n1; i1++)
+                        ptr[i1] = op(ptr1[i1], x2);
+                } else if (dp1 == 0 && dp2 == 1 && dp == 1){
+                    // first operand is broadcast along the row
+                    T x1 = *ptr1;
+                    for(int i1 = 0; i1 < n1; i1++)
+                        ptr[i1] = op(x1, ptr2[i1]);
+                } else {
+                    // generic strided fallback
+                    for(int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr2 += dp2, ptr += dp)
+                        *ptr = op(*ptr1, *ptr2);
+                }
+            }
+        }
+    }
+
+    // Applies the binary functor `f` to inputs[0] and inputs[1] with broadcasting
+    // and stores the result in outputs[0]. Extra inputs, if any, are ignored.
+    template <typename T, typename Functor>
+    void binary_forward(const Functor& f, const std::vector<Mat>& inputs, std::vector<Mat>& outputs)
+    {
+        CV_Assert(inputs.size() >= 2 && !outputs.empty());
+        const Mat& a = inputs[0];
+        const Mat& b = inputs[1];
+        Mat& out = outputs[0];
+
+        // collect info of inputs and output
+        const int* in_shape[] = {a.size.p, b.size.p};
+        const size_t* in_step[] = {a.step.p, b.step.p};
+        const int* out_shape = out.size.p;
+        const size_t* out_step = out.step.p;
+        const int in_ndims[] = {a.dims, b.dims};
+        int out_ndims = out.dims;
+
+        // binary_forward_impl() needs at least two dims
+        int max_ndims = std::max(2, std::max(a.dims, std::max(b.dims, out.dims)));
+
+        // buf holds the following for a, b & output:
+        //  * orig_shapes, shapes (result_shape), orig_steps, steps (result_step), 3*4 elements in total
+        //  * shape_buf & step_buf, 3*2*max_ndims elements in total
+        //  * all_ndims, 3*1 elements in total
+        //  * all_type_sizes, 3*1 elements in total
+        // All pointer and size_t arrays are laid out first and the int arrays last,
+        // so that no size_t array starts at an address that is only int-aligned.
+        AutoBuffer<size_t> buf(3 * (2 * max_ndims + 6));
+
+        int** orig_shapes = (int**)(buf.data());
+        int** shapes = orig_shapes + 3;
+        size_t** orig_steps = (size_t**)(shapes + 3);
+        size_t** steps = orig_steps + 3;
+
+        size_t* step_buf = (size_t*)(steps + 3);
+        size_t* all_type_sizes = step_buf + 3 * max_ndims;
+
+        int* shape_buf = (int*)(all_type_sizes + 3);
+        int* all_ndims = shape_buf + 3 * max_ndims;
+
+        // assign orig_shapes, shapes, orig_steps, steps, all_ndims, all_type_sizes
+        // (slot 0 is the output, slots 1 and 2 are a and b)
+        for (int i = 0; i < 3; i++)
+        {
+            orig_shapes[i] = (int*)(i == 0 ? out_shape : in_shape[i-1]);
+            orig_steps[i] = (size_t*)(i == 0 ? out_step : in_step[i-1]);
+            shapes[i] = shape_buf + i * max_ndims;
+            steps[i] = step_buf + i * max_ndims;
+            all_ndims[i] = i == 0 ? out_ndims : in_ndims[i-1];
+            all_type_sizes[i] = sizeof(T);
+        }
+
+        if (!prepare_for_broadcast_op(3, max_ndims, all_type_sizes,
+                                      all_ndims, (const int**)orig_shapes,
+                                      (const size_t**)orig_steps,
+                                      shapes, steps))
+            return;
+
+        binary_forward_impl<T, Functor>(
+                max_ndims, shapes[0], a.ptr<char>(), steps[1],
+                b.ptr<char>(), steps[2], out.ptr<char>(), steps[0],
+                f);
+    }
+
+    // Inner loops of an n-ary reduction on arrays already normalized by
+    // prepare_for_broadcast_op().
+    //
+    // f      - binary reducer, folded over the inputs from left to right
+    // scale  - factor applied once to the final folded value
+    // steps  - ninputs+1 step arrays: steps[0] is the output, steps[i+1] is input i
+    // ptrs   - scratch array of ninputs+1 pointers, laid out like `steps`
+    //
+    // A single input is passed through as `input * scale`.
+    template <typename T, typename Functor>
+    void nary_forward_impl(
+            const Functor& f, const T scale, int ninputs, int ndims, const int* shape,
+            const char** inp, char* out,
+            const size_t** steps, char** ptrs)
+    {
+        CV_Assert(ndims >= 2);
+        CV_Assert(ninputs >= 1);
+        size_t dp = steps[0][ndims-1]/sizeof(T);
+        size_t dp1 = steps[1][ndims-1]/sizeof(T);
+        // steps[2] exists only if there is a second input
+        size_t dp2 = ninputs >= 2 ? steps[2][ndims-1]/sizeof(T) : 0;
+
+        enum { BLOCK_SIZE = 1024 };
+        T blck[BLOCK_SIZE];
+
+        int k, i, di1=0, n1 = shape[ndims-1], n2 = shape[ndims-2];
+        // The output row must be dense. Its step is zeroed only when the row has
+        // a single element, in which case the step is never used.
+        CV_Assert(dp == 1 || n1 == 1);
+        size_t plane_idx, nplanes = 1;
+        for (k = 0; k < ndims-2; k++) nplanes *= shape[k];
+
+        for (plane_idx = 0; plane_idx < nplanes; plane_idx++) {
+            ptrs[0] = out;
+            for (i = 0; i < ninputs; i++) ptrs[i+1] = (char*)inp[i];
+            size_t idx = plane_idx;
+            for (k = ndims-3; k >= 0; k--) {
+                size_t next_idx = idx/shape[k];
+                int i_k = (int)(idx - next_idx*shape[k]);
+                // ptrs holds the output plus ninputs inputs: advance all ninputs+1 of them
+                for (i = 0; i <= ninputs; i++)
+                    ptrs[i] += i_k*steps[i][k];
+                idx = next_idx;
+            }
+            for (int i2 = 0; i2 < n2; i2++)
+            {
+                const T* ptr1 = (const T*)(ptrs[1] + steps[1][ndims-2]*i2);
+                T* ptr = (T*)(ptrs[0] + steps[0][ndims-2]*i2);
+                if (ninputs == 1) {
+                    // nothing to fold: max/min/sum/mean of one tensor is the tensor itself
+                    for (int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr += dp)
+                        *ptr = saturate_cast<T>(*ptr1 * scale);
+                    continue;
+                }
+                const T* ptr2 = (const T*)(ptrs[2] + steps[2][ndims-2]*i2);
+                if (ninputs == 2) {
+                    if (dp1 == 1 && dp2 == 1) {
+                        for (int i1 = 0; i1 < n1; i1++)
+                            ptr[i1] = saturate_cast<T>(f(ptr1[i1], ptr2[i1])*scale);
+                    } else {
+                        for(int i1 = 0; i1 < n1; i1++, ptr1 += dp1, ptr2 += dp2, ptr += dp)
+                            *ptr = saturate_cast<T>(f(*ptr1, *ptr2)*scale);
+                    }
+                } else {
+                    // more than two inputs: fold block by block through a small
+                    // stack buffer so that the output row is written only once
+                    for (int i1 = 0; i1 < n1; i1 += di1, ptr += di1) {
+                        di1 = BLOCK_SIZE < n1-i1 ? BLOCK_SIZE : n1-i1;
+                        if (dp1 == 1 && dp2 == 1) {
+                            for (int j = 0; j < di1; j++)
+                                blck[j] = f(ptr1[j], ptr2[j]);
+                            ptr1 += di1;
+                            ptr2 += di1;
+                        } else {
+                            for(int j = 0; j < di1; j++, ptr1 += dp1, ptr2 += dp2)
+                                blck[j] = f(*ptr1, *ptr2);
+                        }
+                        for(i = 2; i < ninputs; i++) {
+                            int dp_i = (int)(steps[i+1][ndims-1]/sizeof(T));
+                            const T* ptr_i = (const T*)(ptrs[i+1] +
+                                steps[i+1][ndims-2]*i2) + i1*dp_i;
+                            if (dp_i == 1) {
+                                if (i < ninputs-1) {
+                                    for (int j = 0; j < di1; j++)
+                                        blck[j] = f(blck[j], ptr_i[j]);
+                                } else {
+                                    // last input: apply the scale and store
+                                    for (int j = 0; j < di1; j++)
+                                        ptr[j] = saturate_cast<T>(f(blck[j], ptr_i[j]) * scale);
+                                }
+                            } else {
+                                if (i < ninputs-1) {
+                                    for (int j = 0; j < di1; j++, ptr_i += dp_i)
+                                        blck[j] = f(blck[j], *ptr_i);
+                                } else {
+                                    // last input: apply the scale and store
+                                    for (int j = 0; j < di1; j++, ptr_i += dp_i)
+                                        ptr[j] = saturate_cast<T>(f(blck[j], *ptr_i) * scale);
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    // Folds all inputs with the binary reducer `f` (with broadcasting), multiplies
+    // the folded value by `scale` and stores the result in outputs[0].
+    template <typename T, typename Functor>
+    void nary_forward(
+            const Functor& f, T scale,
+            const std::vector<Mat>& inputs, std::vector<Mat>& outputs
+            )
+    {
+        CV_Assert(!inputs.empty() && !outputs.empty());
+        int ninputs = (int)inputs.size();
+
+        // collect all input
+        std::vector<const char*> v_inp;
+        std::transform(inputs.begin(), inputs.end(), std::back_inserter(v_inp), [] (const Mat& m) { return m.template ptr<const char>(); });
+        const char** inp = v_inp.data();
+
+        // collect ndims of all input
+        std::vector<int> v_inp_dims;
+        std::transform(inputs.begin(), inputs.end(), std::back_inserter(v_inp_dims), [] (const Mat& m) { return m.dims; });
+        const int* inp_ndims = v_inp_dims.data();
+
+        // collect shapes of all input
+        std::vector<const int*> v_inp_shape;
+        std::transform(inputs.begin(), inputs.end(), std::back_inserter(v_inp_shape), [] (const Mat& m) { return m.size.p; });
+        const int** inp_shape = v_inp_shape.data();
+
+        // collect steps of all input
+        std::vector<const size_t*> v_inp_step;
+        std::transform(inputs.begin(), inputs.end(), std::back_inserter(v_inp_step), [] (const Mat& m) { return m.step.p; });
+        const size_t** inp_step = v_inp_step.data();
+
+        // collect info of output (ndims, shape, step)
+        char* out = outputs[0].ptr<char>();
+        int out_ndims = outputs[0].dims;
+        const int* out_shape = outputs[0].size.p;
+        const size_t* out_step = outputs[0].step.p;
+
+        // find max ndims for broadcasting
+        int i, max_ndims = out_ndims > 2 ? out_ndims : 2;
+        for(i = 0; i < ninputs; i++)
+            max_ndims = max_ndims > inp_ndims[i] ? max_ndims : inp_ndims[i];
+
+        // buf holds the following buffers for inputs & output:
+        //  * orig_shapes, shapes (result_shape), orig_steps, steps (result_step), (ninputs+1)*4 elements in total
+        //  * ptrs, (ninputs+1)*1 elements in total
+        //  * shape_buf & step_buf, (ninputs+1)*2*max_ndims elements in total
+        //  * all_ndims, (ninputs+1)*1 elements in total
+        //  * all_type_sizes, (ninputs+1)*1 elements in total
+        // All pointer and size_t arrays are laid out first and the int arrays last,
+        // so that no size_t array starts at an address that is only int-aligned.
+        AutoBuffer<size_t> buf((ninputs + 1) * (2 * max_ndims + 7));
+
+        int** orig_shapes = (int**)buf.data();
+        int** shapes = orig_shapes + ninputs + 1;
+        size_t** orig_steps = (size_t**)(shapes + ninputs + 1);
+        size_t** steps = orig_steps + ninputs + 1;
+
+        char** ptrs = (char**)(steps + ninputs + 1);
+
+        size_t* step_buf = (size_t*)(ptrs + ninputs + 1);
+        size_t* all_type_sizes = step_buf + (ninputs + 1)*max_ndims;
+
+        int* shape_buf = (int*)(all_type_sizes + ninputs + 1);
+        int* all_ndims = shape_buf + (ninputs + 1)*max_ndims;
+
+        // slot 0 is the output, slot i+1 is input i
+        for(i = 0; i <= ninputs; i++) {
+            all_ndims[i] = i == 0 ? out_ndims : inp_ndims[i-1];
+            all_type_sizes[i] = sizeof(T);
+            orig_shapes[i] = (int*)(i == 0 ? out_shape : inp_shape[i-1]);
+            orig_steps[i] = (size_t*)(i == 0 ? out_step : inp_step[i-1]);
+            shapes[i] = shape_buf + max_ndims*i;
+            steps[i] = step_buf + max_ndims*i;
+        }
+
+        if (!prepare_for_broadcast_op(ninputs + 1, max_ndims, all_type_sizes,
+                                      all_ndims, (const int**)orig_shapes,
+                                      (const size_t**)orig_steps,
+                                      shapes, steps))
+            return;
+
+        nary_forward_impl<T>(
+                f, scale, ninputs, max_ndims, shapes[0], inp, out, (const size_t **) steps, ptrs);
+    }
+
+    // Runs the configured operation; the element type is taken from outputs[0].
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        if (inputs_arr.depth() == CV_16S)
+        {
+            forward_fallback(inputs_arr, outputs_arr, internals_arr);
+            return;
+        }
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+        CV_Assert(!inputs.empty() && !outputs.empty());
+
+        // TODO: assert types
+        typeDispatch(outputs[0].type(), inputs.size(), inputs, outputs);
+    }
+
+    // Selects the functor for `op` and forwards `args` (inputs, outputs) to
+    // binary_forward() or nary_forward() instantiated for element type T.
+    //
+    // NOTE(review): the bit-wise and modulo functors take uint8_t operands, so for
+    // T = int32_t the operands are truncated to 8 bits before the operation.
+    // This looks unintended for CV_32S data - confirm before relying on it.
+    template <typename T, typename... Args>
+    inline void opDispatch(size_t ninputs, Args&&... args)
+    {
+        switch (op)
+        {
+            case OPERATION::EQUAL:
+            {
+                auto equal = [](const T &a, const T &b) { return a == b; };
+                binary_forward<T>(equal, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::GREATER:
+            {
+                auto greater = [](const T &a, const T &b) { return a > b; };
+                binary_forward<T>(greater, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::GREATER_EQUAL:
+            {
+                auto greater_equal = [](const T &a, const T &b) { return a >= b; };
+                binary_forward<T>(greater_equal, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::LESS:
+            {
+                auto less = [](const T &a, const T &b) { return a < b; };
+                binary_forward<T>(less, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::LESS_EQUAL:
+            {
+                auto less_equal = [](const T &a, const T &b) { return a <= b; };
+                binary_forward<T>(less_equal, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::POW:
+            {
+                auto pow = [] (const T& a, const T& b) { return std::pow(a, b); };
+                binary_forward<T>(pow, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::BITSHIFT:
+            {
+                auto bitshift = [] (const uint8_t &a, const uint8_t &b) { return a << b; };
+                binary_forward<T>(bitshift, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::MAX:
+            {
+                auto max = [](const T &a, const T &b) { return std::max(a, b); };
+                nary_forward<T>(max, T{1}, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::MEAN:
+            {
+                // Fold with a plain sum and divide once through the scale.
+                // Averaging pairwise *and* scaling by 1/ninputs would divide twice.
+                // For integer T the reciprocal truncates, so only floating-point
+                // means are meaningful here.
+                auto mean = [](const T &a, const T &b) { return a + b; };
+                nary_forward<T>(mean, T{1} / ninputs, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::MIN:
+            {
+                auto min = [](const T &a, const T &b) { return std::min(a, b); };
+                nary_forward<T>(min, T{1}, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::MOD:
+            {
+                auto mod = [](const uint8_t &a, const uint8_t &b) { return a % b; };
+                binary_forward<T>(mod, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::PROD:
+            {
+                auto prod = [](const T &a, const T &b) { return a * b; };
+                binary_forward<T>(prod, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::SUB:
+            {
+                auto sub = [](const T &a, const T &b) { return a - b; };
+                binary_forward<T>(sub, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::SUM:
+            {
+                auto sum = [](const T &a, const T &b) { return a + b; };
+                nary_forward<T>(sum, T{1}, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::ADD:
+            {
+                auto add = [](const T &a, const T &b) { return a + b; };
+                binary_forward<T>(add, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::DIV:
+            {
+                auto div = [](const T &a, const T &b) { return a / b; };
+                binary_forward<T>(div, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::AND:
+            {
+                auto op_and = [](const uint8_t &a, const uint8_t &b) { return a & b; };
+                binary_forward<T>(op_and, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::OR:
+            {
+                auto op_or = [](const uint8_t &a, const uint8_t &b) { return a | b; };
+                binary_forward<T>(op_or, std::forward<Args>(args)...);
+                break;
+            }
+            case OPERATION::XOR:
+            {
+                auto op_xor = [](const uint8_t &a, const uint8_t &b) { return a ^ b; };
+                binary_forward<T>(op_xor, std::forward<Args>(args)...);
+                break;
+            }
+            default:
+                CV_Error(Error::StsBadArg, "Unsupported operation.");
+        };
+    }
+
+    // Maps the OpenCV depth `type` to the C++ element type and calls opDispatch().
+    // Bit-wise and modulo operations are rejected for CV_32F.
+    template<typename... Args>
+    inline void typeDispatch(const int type, Args&&... args)
+    {
+        switch (type)
+        {
+            case CV_8U:
+                opDispatch<uint8_t>(std::forward<Args>(args)...);
+                break;
+            case CV_32S:
+                opDispatch<int32_t>(std::forward<Args>(args)...);
+                break;
+            case CV_32F:
+                CV_Assert(op != OPERATION::BITSHIFT && op != OPERATION::MOD &&
+                          op != OPERATION::AND && op != OPERATION::OR &&
+                          op != OPERATION::XOR);
+                opDispatch<float>(std::forward<Args>(args)...);
+                break;
+            default:
+                CV_Error(cv::Error::BadDepth, "Unsupported type.");
+        };
+    }
+
+    // Quantization is not supported by this layer.
+    virtual bool tryQuantize(const std::vector<std::vector<float> > &scales,
+                             const std::vector<std::vector<int> > &zeropoints, LayerParams& params) CV_OVERRIDE
+    {
+        return false;
+    }
+
+    // Counts one operation per input for every output element.
+    virtual int64 getFLOPS(const std::vector<MatShape> &inputs,
+                           const std::vector<MatShape> &outputs) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size());
+        return inputs.size() * total(outputs[0]);
+    }
+};
+
+Ptr<NaryEltwiseLayer> NaryEltwiseLayer::create(const LayerParams& params)
+{
+    return Ptr<NaryEltwiseLayer>(new NaryEltwiseLayerImpl(params));
+}
+
+}
+}
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index ebbda98d51..e99fbb319e 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -63,10 +63,17 @@ class ONNXImporter LayerInfo(int _layerId = 0, int _outputId = 0) : layerId(_layerId), outputId(_outputId) {} }; + struct TensorInfo { + int real_ndims; + TensorInfo(int _real_ndims = 0) : real_ndims(_real_ndims) {} + }; + std::map getGraphTensors( const opencv_onnx::GraphProto& graph_proto); Mat getBlob(const opencv_onnx::NodeProto& node_proto, int index); Mat getBlob(const std::string& input_name); + TensorInfo getBlobExtraInfo(const opencv_onnx::NodeProto& node_proto, int index); + TensorInfo getBlobExtraInfo(const std::string& input_name); LayerParams getLayerParams(const opencv_onnx::NodeProto& node_proto); @@ -101,6 +108,7 @@ protected: std::string framework_name; std::map constBlobs; + std::map constBlobsExtraInfo; 
std::map outShapes; // List of internal blobs shapes. bool hasDynamicShapes; // Whether the model has inputs with dynamic shapes @@ -134,9 +142,6 @@ private: void parseReduce (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseSlice (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseSplit (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); - void parseBias (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); - void parsePow (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); - void parseMinMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseNeg (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseConstant (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseLSTM (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); @@ -148,14 +153,12 @@ private: void parseElu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseTanh (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseAbs (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); - void parseCompare (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parsePRelu (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseLRN (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseInstanceNormalization(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseBatchNormalization (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseGemm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseMatMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); - void parseMul (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseConv (LayerParams& layerParams, const 
opencv_onnx::NodeProto& node_proto); void parseConvTranspose (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseTranspose (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); @@ -175,6 +178,7 @@ private: void parseSoftMax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseDetectionOutput (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseCumSum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); + void parseElementWise (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseDepthToSpace (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseSimpleLayers (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); @@ -399,6 +403,7 @@ std::map ONNXImporter::getGraphTensors( continue; layers_weights.insert(std::make_pair(tensor_proto.name(), mat)); + constBlobsExtraInfo.insert(std::make_pair(tensor_proto.name(), TensorInfo(tensor_proto.dims_size()))); } return layers_weights; } @@ -506,6 +511,7 @@ LayerParams ONNXImporter::getLayerParams(const opencv_onnx::NodeProto& node_prot opencv_onnx::TensorProto tensor = attribute_proto.t(); Mat blob = getMatFromTensor(tensor); lp.blobs.push_back(blob); + lp.set("original_dims_of_mat", tensor.dims_size()); } else if (attribute_proto.has_g()) { @@ -573,6 +579,23 @@ Mat ONNXImporter::getBlob(const std::string& input_name) return constBlob->second; } +ONNXImporter::TensorInfo ONNXImporter::getBlobExtraInfo(const opencv_onnx::NodeProto &node_proto, int index) +{ + CV_Assert(index < node_proto.input_size()); + const std::string& input_name = node_proto.input(index); + return getBlobExtraInfo(input_name); +} + +ONNXImporter::TensorInfo ONNXImporter::getBlobExtraInfo(const std::string& input_name) +{ + std::map::const_iterator constBlobExtraInfo = constBlobsExtraInfo.find(input_name); + if (constBlobExtraInfo == constBlobsExtraInfo.end()) + { + 
CV_Error(Error::StsBadArg, std::string("Blob ") + input_name + " not found in const blobs of extra info"); + } + return constBlobExtraInfo->second; +} + void ONNXImporter::addLayer(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { @@ -1429,145 +1452,6 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP addLayer(layerParams, node_proto); } -void ONNXImporter::parseBias(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) -{ - opencv_onnx::NodeProto node_proto = node_proto_; - const std::string& layer_type = node_proto.op_type(); - bool isSub = layer_type == "Sub"; - - if (layer_type == "Sum" && node_proto.input_size() == 1) - { - layerParams.type = "Identity"; - addLayer(layerParams, node_proto); - return; - } - - CV_Assert((node_proto.input_size() == 2) || (layer_type == "Sum" && node_proto.input_size() > 2)); - - if (layer_type == "Sum" && node_proto.input_size() > 2) - { - for (int i = 0; i < node_proto.input_size(); ++i) - { - if (layer_id.find(node_proto.input(i)) == layer_id.end()) - { - CV_Error(Error::StsNotImplemented, "Sum of constants is not implemented for inputs > 2"); - } - } - } - - bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end(); - bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end(); - if (is_const_0 && is_const_1) - { - Mat blob_0 = getBlob(node_proto, 0); - Mat blob_1 = getBlob(node_proto, 1); - CV_Assert(blob_0.size == blob_1.size); - Mat output = isSub ? (blob_0 - blob_1) : (blob_0 + blob_1); - addConstant(node_proto.output(0), output); - return; - } - else if (is_const_0 || is_const_1) - { - int const_blob_id = is_const_0 ? 0 : 1; - int input_id = 1 - const_blob_id; - Mat blob = getBlob(node_proto, const_blob_id); - int blob_total = blob.total(); - - const float inputScale = isSub && is_const_0 ? -1.f : 1.f; - const float constScale = isSub && is_const_1 ? 
-1.f : 1.f; - - if (blob_total == 1) { - layerParams.type = "Power"; - layerParams.set("scale", inputScale); - layerParams.set("shift", constScale * blob.ptr()[0]); - } - else { - MatShape inpShape = outShapes[node_proto.input(input_id)]; - if (shape(blob) == inpShape) - { - LayerParams constParams; - constParams.name = layerParams.name + "/const"; - constParams.type = "Const"; - constParams.blobs.push_back(blob); - int id = dstNet.addLayer(constParams.name, constParams.type, constParams); - layer_id.insert(std::make_pair(constParams.name, LayerInfo(id, 0))); - outShapes[constParams.name] = shape(blob); - - layerParams.type = "Eltwise"; - float coeffs[] = {1., isSub ? -1.f : 1.f}; - layerParams.set("coeff", DictValue::arrayReal(coeffs, 2)); - node_proto.set_input(const_blob_id, constParams.name); - } - else - { - if (inputScale < 0.f) - { - addNegation(layerParams, node_proto, input_id); - } - - layerParams.type = "Scale"; - layerParams.set("bias_term", true); - int axis = 1; - for (int i = 0; i < graph_proto.initializer_size(); i++) - { - opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i); - if (tensor_proto.name() == node_proto.input(const_blob_id)) - { - axis = inpShape.size() - tensor_proto.dims_size(); - break; - } - } - layerParams.set("axis", axis); - blob = blob.reshape(1, 1); - layerParams.blobs.push_back(constScale * blob); - } - } - } - else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)]) - { - layerParams.type = "Eltwise"; - if (isSub) - { - static float subCoeffs[] = {1.f, -1.f}; - layerParams.set("coeff", DictValue::arrayReal(subCoeffs, 2)); - } - } - else - { - if (isSub) - { - addNegation(layerParams, node_proto, 1); - } - layerParams.type = "Scale"; - layerParams.set("bias_term", true); - } - addLayer(layerParams, node_proto); -} - -void ONNXImporter::parsePow(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) -{ - if (layer_id.find(node_proto.input(1)) != layer_id.end()) - 
CV_Error(Error::StsNotImplemented, "Unsupported Pow op with variable power"); - - Mat blob = getBlob(node_proto, 1); - if (blob.total() != 1) - CV_Error(Error::StsNotImplemented, "Pow op supports only scalar power"); - - blob.convertTo(blob, CV_32F); - layerParams.type = "Power"; - layerParams.set("power", blob.ptr()[0]); - addLayer(layerParams, node_proto); -} - -// "Min" "Max" -void ONNXImporter::parseMinMax(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) -{ - const std::string& layer_type = node_proto.op_type(); - layerParams.type = "Eltwise"; - layerParams.set("operation", layer_type == "Max" ? "max" : "min"); - addLayer(layerParams, node_proto); -} - void ONNXImporter::parseNeg(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { layerParams.type = "Power"; @@ -1580,6 +1464,12 @@ void ONNXImporter::parseConstant(LayerParams& layerParams, const opencv_onnx::No CV_Assert(node_proto.input_size() == 0); CV_Assert(layerParams.blobs.size() == 1); addConstant(node_proto.output(0), layerParams.blobs[0]); + // add constant for constBlobsExtraInfo + if (layerParams.has("original_dims_of_mat")) + { + int original_dims_of_mat = layerParams.get("original_dims_of_mat"); + constBlobsExtraInfo.insert(std::make_pair(node_proto.output(0), TensorInfo(original_dims_of_mat))); + } } void transformBlobs(std::vector& blobs) @@ -1988,32 +1878,6 @@ void ONNXImporter::parseAbs(LayerParams& layerParams, const opencv_onnx::NodePro addLayer(layerParams, node_proto); } -void ONNXImporter::parseCompare(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) -{ - CV_Assert(node_proto.input_size() == 2); - const std::string& layer_type = node_proto.op_type(); - - bool is_const_0 = layer_id.find(node_proto.input(0)) == layer_id.end(); - bool is_const_1 = layer_id.find(node_proto.input(1)) == layer_id.end(); - - if (is_const_0 || is_const_1) - { - Mat blob = getBlob(node_proto, static_cast(is_const_1)); - blob = blob.reshape(1, 1); - 
layerParams.blobs.push_back(blob); - } - - layerParams.type = "Compare"; - - if (layer_type == "Equal") - layerParams.set("mode", "equal"); - else if (layer_type == "Greater") - layerParams.set("mode", "greater"); - else - layerParams.set("mode", "less"); - addLayer(layerParams, node_proto); -} - void ONNXImporter::parsePRelu(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { layerParams.type = "PReLU"; @@ -2189,169 +2053,6 @@ void findBroadAxis(const MatShape& broadShape, const MatShape& outShape, size_t& axis += diff; } -// "Mul" "Div" -void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) -{ - opencv_onnx::NodeProto node_proto = node_proto_; - const std::string& layer_type = node_proto.op_type(); - const std::string output_name = node_proto.output(0); - CV_Assert(node_proto.input_size() == 2); - - bool isDiv = layer_type == "Div"; - int constId = -1; - bool haveVariables = false; - for (int i = 0; i < 2; ++i) - { - if (constBlobs.find(node_proto.input(i)) != constBlobs.end()) - constId = i; - else - haveVariables = true; - } - if (constId != -1 && haveVariables) - { - Mat blob = getBlob(node_proto, constId); - blob = blob.reshape(1, 1); - if (blob.total() == 1) { - float blob_value = blob.ptr()[0]; - float coeff = blob_value; - if (isDiv) - { - coeff = 1.f / blob_value; - if (constId == 0) - { - // Power layer calculates (x*scale + shift)^power, so const/x -> (x * (1/const) + 0)^(-1) - layerParams.set("power", -1.f); - } - } - layerParams.set("scale", coeff); - layerParams.type = "Power"; - } - else { - if (isDiv) - divide(1.0, blob, blob); - layerParams.blobs.push_back(blob); - layerParams.type = "Scale"; - } - } - else if (!haveVariables) - { - Mat inp0 = getBlob(node_proto, 0); - Mat inp1 = getBlob(node_proto, 1); - - if (inp0.size != inp1.size && (inp0.total() != 1 || inp1.total() != 1)) - CV_Error_(Error::StsNotImplemented, ("Different shapes case is not supported with constant inputs: %s", 
layer_type.c_str())); - - if (inp0.total() == 1 && inp1.total() == 1 && inp0.dims != inp1.dims) - { - if (inp0.dims < inp1.dims) - { - inp0 = inp0.reshape(1, inp1.dims, inp1.size); - inp0.dims = inp1.dims; - } - else - { - inp1 = inp1.reshape(1, inp0.dims, inp0.size); - inp1.dims = inp0.dims; - } - } - - Mat out; - if (inp0.total() != inp1.total()) - { - if (inp0.total() == 1) - { - float inp0_value = inp0.ptr()[0]; - float coeff = isDiv ? 1.0 / inp0_value : inp0_value; - multiply(inp1, coeff, out); - } - else - { - float inp1_value = inp1.ptr()[0]; - float coeff = isDiv ? 1.0 / inp1_value : inp1_value; - multiply(inp0, coeff, out); - } - - } - else - { - out = isDiv ? inp0 / inp1 : inp0.mul(inp1); - } - - if (inp0.dims == 1 && inp1.dims == 1) - out.dims = 1; // to workaround dims == 1 - addConstant(output_name, out); - return; - } - else if (outShapes[node_proto.input(0)] == outShapes[node_proto.input(1)]) - { - layerParams.type = "Eltwise"; - layerParams.set("operation", isDiv ? "div" : "prod"); - } - else - { - // Scale layer allocate output with the first input shape - if (total(outShapes[node_proto.input(0)]) < total(outShapes[node_proto.input(1)])) - { - opencv_onnx::NodeProto proto; - proto.add_input(node_proto.input(1)); - proto.add_input(node_proto.input(0)); - proto.add_output(output_name); - node_proto = proto; - } - - if (isDiv) - { - LayerParams powerParams; - powerParams.name = layerParams.name + "/inv"; - powerParams.type = "Power"; - powerParams.set("power", -1); - - //Create Power layer - int id = dstNet.addLayer(powerParams.name, powerParams.type, powerParams); - //Connect to input - IterLayerId_t layerId = layer_id.find(node_proto.input(1)); - CV_Assert(layerId != layer_id.end()); - dstNet.connect(layerId->second.layerId, layerId->second.outputId, id, 0); - //Add shape - layer_id.insert(std::make_pair(powerParams.name, LayerInfo(id, 0))); - outShapes[powerParams.name] = outShapes[node_proto.input(1)]; - - //Replace input to Power - 
node_proto.set_input(1, powerParams.name); - } - - const MatShape& broadShape = outShapes[node_proto.input(1)]; - const MatShape& outShape = outShapes[node_proto.input(0)]; - - size_t axis = 0; - int broadAxis = -1; - findBroadAxis(broadShape, outShape, axis, broadAxis); - - // if there is a one dimension in the middle that should be broadcasted, broadcast it - if (broadAxis != -1) - { - opencv_onnx::NodeProto concat_node_proto = node_proto; - const std::string& input1 = concat_node_proto.input(1); - - expandMid(layerParams.name, concat_node_proto, input1, outShape[broadAxis]); - - LayerParams concatLP; - concatLP.name = layerParams.name + "/concat"; - concatLP.set("axis", broadAxis); - concatLP.type = "Concat"; - concat_node_proto.set_output(0, concatLP.name); - - addLayer(concatLP, concat_node_proto); - node_proto.set_input(1, concatLP.name); - } - - CV_Assert(axis != outShape.size()); - layerParams.set("axis", static_cast(axis)); - layerParams.type = "Scale"; - } - addLayer(layerParams, node_proto); -} - void ONNXImporter::parseConv(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) { opencv_onnx::NodeProto node_proto = node_proto_; @@ -2542,6 +2243,10 @@ void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::Nod if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) { Mat input = getBlob(node_proto, 0); + if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end()) + { + constBlobsExtraInfo.insert(std::make_pair(node_proto.output(0), getBlobExtraInfo(node_proto, 0))); + } int axis = normalize_axis(axis_, input.dims); int out_size[2] = {1, 1}; @@ -2614,12 +2319,16 @@ void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::N { // Constant input. 
Mat input = getBlob(node_proto, 0); + int input_dims = input.dims; + if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end()) + if (getBlobExtraInfo(node_proto, 0).real_ndims == 1) + input_dims = 1; std::vector dims; - for (int j = 0; j < input.dims; j++) { + for (int j = 0; j < input_dims; j++) { dims.push_back(input.size[j]); } - CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size()); +// CV_Assert(axes.getIntValue(axes.size()-1) <= dims.size()); for (int j = 0; j < axes.size(); j++) { const int idx = axes.getIntValue(j); CV_Assert(idx <= dims.size()); @@ -2874,6 +2583,10 @@ void ONNXImporter::parseCast(LayerParams& layerParams, const opencv_onnx::NodePr if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) { Mat blob = getBlob(node_proto, 0); + if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end()) + { + constBlobsExtraInfo.insert(std::make_pair(node_proto.output(0), getBlobExtraInfo(node_proto, 0))); + } int type; switch (layerParams.get("to")) { @@ -3011,6 +2724,10 @@ void ONNXImporter::parseConcat(LayerParams& layerParams, const opencv_onnx::Node break; } } + if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end()) + { + constBlobsExtraInfo.insert(std::make_pair(node_proto.output(0), getBlobExtraInfo(node_proto, 0))); + } if (!hasVariableInps) { @@ -3223,6 +2940,78 @@ void ONNXImporter::parseCumSum(LayerParams& layerParams, const opencv_onnx::Node addLayer(layerParams, node_proto); } +// "Equal" "Greater" "Less" "Pow" "Add" "Sub" "Mul" "Div" "Sum" "Min" "Max" +void ONNXImporter::parseElementWise(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) +{ + opencv_onnx::NodeProto node_proto = node_proto_; + String op_type = toLowerCase(node_proto.op_type()); + + layerParams.type = "NaryEltwise"; + layerParams.set("operation", toLowerCase(node_proto.op_type())); + + // element-wise layers that can have >=1 inputs but actually have one input + if 
(node_proto.input_size() == 1 && (op_type == "max" || op_type == "min" || op_type == "mean" || op_type == "sum")) + { + layerParams.type = "Identity"; + addLayer(layerParams, node_proto); + return; + } + + auto pre_broadcast_transform = [](Mat& t, int t_real_ndims) { + if (t.dims == 2 && t_real_ndims == 1 && t.size[1] == 1) + transpose(t, t); + }; + + size_t consts = 0; + for (size_t i = 0; i < node_proto.input_size(); ++i) + { + if (layer_id.find(node_proto.input(i)) == layer_id.end()) + { + ++consts; + } + } + + if (consts == node_proto.input_size()) + { + std::vector inputs, output; + for (size_t i = 0; i < node_proto.input_size(); ++i) + { + inputs.push_back(getBlob(node_proto, i)); + } + runLayer(layerParams, inputs, output); + CV_Assert(output.size() == 1); + addConstant(node_proto.output(0), output[0]); + return; + } + else if (consts > 0) + { + for (size_t i = 0; i < node_proto.input_size(); ++i) + { + if (layer_id.find(node_proto.input(i)) == layer_id.end()) + { + Mat inp = getBlob(node_proto, i); + // for cases like a tensor of shape (2,), it will be loaded as shape (2, 1) in OpenCV Mat, + // but for correct broadcast, we need to make it of shape (1, 2) + if (constBlobsExtraInfo.find(node_proto.input(i)) != constBlobsExtraInfo.end()) + pre_broadcast_transform(inp, getBlobExtraInfo(node_proto, i).real_ndims); + + // carry the constant by adding a Const node + LayerParams constParams; + constParams.name = node_proto.input(i); + constParams.type = "Const"; + constParams.blobs.push_back(inp); + + opencv_onnx::NodeProto proto; + proto.add_output(constParams.name); + addLayer(constParams, proto); + } + } + } + + // add element-wise layer + addLayer(layerParams, node_proto); +} + void ONNXImporter::parseDepthToSpace(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) { // We parse "DepthToSpace" and "SpaceToDepth" in this function. 
@@ -3794,9 +3583,6 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["ReduceL2"] = dispatch["ReduceLogSum"] = dispatch["ReduceLogSumExp"] = &ONNXImporter::parseReduce; dispatch["Slice"] = &ONNXImporter::parseSlice; dispatch["Split"] = &ONNXImporter::parseSplit; - dispatch["Add"] = dispatch["Sum"] = dispatch["Sub"] = &ONNXImporter::parseBias; - dispatch["Pow"] = &ONNXImporter::parsePow; - dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseMinMax; dispatch["Neg"] = &ONNXImporter::parseNeg; dispatch["Constant"] = &ONNXImporter::parseConstant; dispatch["LSTM"] = &ONNXImporter::parseLSTM; @@ -3808,14 +3594,12 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["Elu"] = &ONNXImporter::parseElu; dispatch["Tanh"] = &ONNXImporter::parseTanh; dispatch["Abs"] = &ONNXImporter::parseAbs; - dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = &ONNXImporter::parseCompare; dispatch["PRelu"] = &ONNXImporter::parsePRelu; dispatch["LRN"] = &ONNXImporter::parseLRN; dispatch["InstanceNormalization"] = &ONNXImporter::parseInstanceNormalization; dispatch["BatchNormalization"] = &ONNXImporter::parseBatchNormalization; dispatch["Gemm"] = &ONNXImporter::parseGemm; dispatch["MatMul"] = &ONNXImporter::parseMatMul; - dispatch["Mul"] = dispatch["Div"] = &ONNXImporter::parseMul; dispatch["Conv"] = &ONNXImporter::parseConv; dispatch["ConvTranspose"] = &ONNXImporter::parseConvTranspose; dispatch["Transpose"] = &ONNXImporter::parseTranspose; @@ -3837,6 +3621,10 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["CumSum"] = &ONNXImporter::parseCumSum; dispatch["SpaceToDepth"] = dispatch["DepthToSpace"] = &ONNXImporter::parseDepthToSpace; + dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] = + dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = &ONNXImporter::parseElementWise; + dispatch["Sum"] = dispatch["Min"] = dispatch["Max"] = &ONNXImporter::parseElementWise; + 
std::vector simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos", "Cosh", "Dropout", "Erf", "Exp", "Floor", "HardSigmoid", "HardSwish", "Identity", "Log", "Round", "Reciprocal", "Selu", "Sign", "Sigmoid", "Sin", "Sinh", "Softmax", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp index dd0965904a..292cd2a066 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_all_denylist.inl.hpp @@ -55,3 +55,5 @@ "test_sub_bcast", "test_sub_uint8", // output type mismatch "test_upsample_nearest", +"test_div_bcast", // remove when 1D Mat is supported +"test_mul_bcast", // remove when 1D Mat is supported diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_cpu_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_cpu_denylist.inl.hpp index e69de29bb2..8b13789179 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_cpu_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_cpu_denylist.inl.hpp @@ -0,0 +1 @@ +