diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp index 1cbc654603..46c5f338af 100644 --- a/modules/dnn/include/opencv2/dnn/all_layers.hpp +++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp @@ -1067,6 +1067,18 @@ CV__DNN_INLINE_NS_BEGIN static Ptr create(const LayerParams& params); }; + class CV_EXPORTS ScatterLayer : public Layer + { + public: + static Ptr create(const LayerParams& params); + }; + + class CV_EXPORTS ScatterNDLayer : public Layer + { + public: + static Ptr create(const LayerParams& params); + }; + //! @} //! @} CV__DNN_INLINE_NS_END diff --git a/modules/dnn/perf/perf_layer.cpp b/modules/dnn/perf/perf_layer.cpp index 03ba8ab0e9..f169f4e6a8 100644 --- a/modules/dnn/perf/perf_layer.cpp +++ b/modules/dnn/perf/perf_layer.cpp @@ -239,7 +239,178 @@ PERF_TEST_P_(Layer_Slice, FastNeuralStyle_eccv16) test_slice<4>(inputShape, begin, end); } +struct Layer_Scatter : public TestBaseWithParam > +{ + void test_layer(const std::vector& shape, const String reduction = "none", int axis = 0) + { + int backendId = get<0>(GetParam()); + int targetId = get<1>(GetParam()); + + Mat data(shape, CV_32FC1); + Mat indices(shape, CV_32FC1); + Mat updates(shape, CV_32FC1); + + Scalar mean = 0.f; + Scalar std = 1.f; + randn(data, mean, std); + randu(indices, 0, shape[axis]); + randn(updates, mean, std); + + indices.convertTo(indices, CV_32SC1, 1, -1); + + Net net; + LayerParams lp; + lp.type = "Scatter"; + lp.name = "testLayer"; + lp.set("reduction", reduction); + lp.set("axis", axis); + + int id = net.addLayerToPrev(lp.name, lp.type, lp); + net.connect(0, 0, id, 0); + net.connect(0, 1, id, 1); + net.connect(0, 2, id, 2); + + // warmup + { + std::vector inpNames(3); + inpNames[0] = "data"; + inpNames[1] = "indices"; + inpNames[2] = "updates"; + net.setInputsNames(inpNames); + net.setInput(data, inpNames[0]); + net.setInput(indices, inpNames[1]); + net.setInput(updates, inpNames[2]); + + net.setPreferableBackend(backendId); + net.setPreferableTarget(targetId); + Mat out = net.forward(); + } + + TEST_CYCLE() + { + Mat res = net.forward(); + } + + SANITY_CHECK_NOTHING(); + } + + int N = 8; + int C = 256; + int H = 128; + int W = 100; +}; + +PERF_TEST_P_(Layer_Scatter, DISABLED_Scatter) +{ + test_layer({N, C, H, W}); +} + +PERF_TEST_P_(Layer_Scatter, DISABLED_Scatter_add) +{ + test_layer({N, C, H, W}, "add"); +} + +struct Layer_ScatterND : public TestBaseWithParam > +{ + void test_layer(const std::vector& shape, const String reduction = "none") + { + int backendId = get<0>(GetParam()); + int targetId = get<1>(GetParam()); + + std::vector indices_shape(shape); + indices_shape.push_back(int(shape.size())); + Mat data(shape, CV_32FC1); + Mat indices(indices_shape, CV_32FC1); + Mat updates(shape, CV_32FC1); + + Scalar mean = 0.f; + Scalar std = 1.f; + randn(data, mean, std); + randn(updates, mean, std); + + // initialize the indices with index tuples like [0...N, 0...C, 0...H, 0...W] + std::vector current_index_tuple(shape.size()); + int total = data.total(); + std::vector indices_step; + for (int i = 0; i < indices.dims; i++) + { + int step = indices.step.p[i] / sizeof(float); + indices_step.push_back(step); + } + int t, j, idx, offset_at_idx, offset; + for (int i = 0; i < total; i++) + { + t = i; + for (j = shape.size() - 1; j >= 0; j--) + { + idx = t / shape[j]; + offset_at_idx = (int)(t - idx * shape[j]); + current_index_tuple[j] = offset_at_idx; + t = idx; + } + + offset = 0; + for (j = 0; j < shape.size(); j++) + offset += current_index_tuple[j] * indices_step[j]; + + for (j = 0; j < shape.size(); j++) + indices.at(offset + j) = current_index_tuple[j]; + } + + Net net; + LayerParams lp; + lp.type = "ScatterND"; + lp.name = "testLayer"; + lp.set("reduction", reduction); + + int id = net.addLayerToPrev(lp.name, lp.type, lp); + net.connect(0, 0, id, 0); + net.connect(0, 1, id, 1); + net.connect(0, 2, id, 2); + + // warmup + { + std::vector inpNames(3); + inpNames[0] = "data"; + inpNames[1] = "indices"; + inpNames[2] = "updates"; + net.setInputsNames(inpNames); + net.setInput(data, inpNames[0]); + net.setInput(indices, inpNames[1]); + net.setInput(updates, inpNames[2]); + + net.setPreferableBackend(backendId); + net.setPreferableTarget(targetId); + Mat out = net.forward(); + } + + TEST_CYCLE() + { + Mat res = net.forward(); + } + + SANITY_CHECK_NOTHING(); + } + + int N = 8; + int C = 256; + int H = 128; + int W = 100; +}; + +PERF_TEST_P_(Layer_ScatterND, DISABLED_ScatterND) +{ + test_layer({N, C, H ,W}); +} + +PERF_TEST_P_(Layer_ScatterND, DISABLED_ScatterND_add) +{ + test_layer({N, C, H , W}, "add"); +} + INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false)); INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); +INSTANTIATE_TEST_CASE_P(/**/, Layer_Scatter, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); +INSTANTIATE_TEST_CASE_P(/**/, Layer_ScatterND, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU))); } // namespace diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp index f77523916b..27956097ac 100644 --- a/modules/dnn/src/init.cpp +++ b/modules/dnn/src/init.cpp @@ -181,6 +181,9 @@ void initializeLayerFactory() CV_DNN_REGISTER_LAYER_CLASS(GRU, GRULayer); CV_DNN_REGISTER_LAYER_CLASS(CumSum, CumSumLayer); + CV_DNN_REGISTER_LAYER_CLASS(Scatter, ScatterLayer); + CV_DNN_REGISTER_LAYER_CLASS(ScatterND, ScatterNDLayer); + CV_DNN_REGISTER_LAYER_CLASS(Quantize, QuantizeLayer); CV_DNN_REGISTER_LAYER_CLASS(Dequantize, DequantizeLayer); CV_DNN_REGISTER_LAYER_CLASS(Requantize, RequantizeLayer); diff --git a/modules/dnn/src/layers/scatterND_layer.cpp b/modules/dnn/src/layers/scatterND_layer.cpp new file mode 100644 index 0000000000..648d35fc0c --- /dev/null +++ b/modules/dnn/src/layers/scatterND_layer.cpp @@ -0,0 +1,202 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "../precomp.hpp" +#include "layers_common.hpp" + +#include // for std::max & std::min + +namespace cv { namespace dnn { + +class ScatterNDLayerImpl CV_FINAL : public ScatterNDLayer +{ +public: + enum class REDUCTION + { + NONE = 1, + ADD, + MUL, + MAX, + MIN + } reduction; + + ScatterNDLayerImpl(const LayerParams& params) + { + setParamsFrom(params); + + String reduction_name = toLowerCase(params.get("reduction", "none")); + if (reduction_name == "none") + reduction = REDUCTION::NONE; + else if (reduction_name == "add") + reduction = REDUCTION::ADD; + else if (reduction_name == "mul") + reduction = REDUCTION::MUL; + else if (reduction_name == "max") + reduction = REDUCTION::MAX; + else if (reduction_name == "min") + reduction = REDUCTION::MIN; + else + CV_Error(cv::Error::StsBadArg, "Unkown reduction \"" + reduction_name + "\""); + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE + { + return backendId == DNN_BACKEND_OPENCV; + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE + { + CV_CheckEQ(inputs.size(), 3ull, "ScatterND: require three inputs."); + + size_t r = inputs[0].size(), q = inputs[1].size(), p = inputs[2].size(), k = inputs[1].back(); + CV_CheckEQ(r + q - inputs[1].back() - 1, p, "ScatterND: updates should have rank of data.dims + indices.dims - indices.size[-1] - 1"); + CV_CheckLE(k, r, "ScatterND: indices.shape[-1] must be less than (or equal to) the rank of input data."); + + for (int i = 0; i < q - 1; i++) // np.ndindex(indices.shape[-1]) + { + CV_CheckEQ(inputs[2][i], inputs[1][i], "ScatterND: updates.shape[0 : rank(indices)-1] must equal to indices.shape[0 : rank(indices)-1]."); + } + for (int i = q - 1, j = k, m = 0; i + m < p; m++) + { + CV_CheckEQ(inputs[2][i + m], inputs[0][j + m], "ScatterND: updates.shape[rank(indices)-1 : ] must equal to data[indices.shape[-1] : rank(data)-1]."); + } + + outputs.assign(1, inputs[0]); + return false; + } + + void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE + { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + + const Mat& data = inputs[0]; + const Mat& indices = inputs[1]; + const Mat& updates = inputs[2]; + Mat& out = outputs[0]; + + typeDispatch(outputs[0].type(), data, indices, updates, out); + } + + // NOTE: This impl does not check whether indices have duplicate entries. + // The last duplicate entry will overwrite the previous. + template + void forward_impl(const Functor& rd, const Mat& data, const Mat& indices, const Mat& updates, Mat& out) + { + data.copyTo(out); + + const int* shape = data.size.p; + const size_t* step = data.step.p; + + const int ind_ndims = indices.dims; + const int* ind_shape = indices.size.p; + const T* p_indices = indices.ptr(); + + const int upd_ndims = updates.dims; + const int* upd_shape = updates.size.p; + const T* p_updates = updates.ptr(); + + T* p_out = out.ptr(); + + int k = ind_shape[ind_ndims - 1]; // last dim of indices + size_t total = (size_t)(indices.total() / k); + + size_t updates_size = 1; + for (int i = ind_ndims - 1; i < upd_ndims; i++) + updates_size *= upd_shape[i]; + + size_t inp_start_offset = 0; + size_t ind_start_offset = 0; + size_t upd_start_offset = 0; + for (size_t i = 0; i < total; i++, ind_start_offset += k, upd_start_offset += updates_size) + { + const T* tmp_p_indices = p_indices + ind_start_offset; + inp_start_offset = 0; + for (int j = 0; j < k; j++) + { + CV_Assert(tmp_p_indices[j] < shape[j] && tmp_p_indices[j] > -shape[j]); + inp_start_offset += (((int)tmp_p_indices[j] + shape[j]) % shape[j]) * step[j]; + } + inp_start_offset /= sizeof(T); + + const T* tmp_p_updates = p_updates + upd_start_offset; + T* tmp_p_out = p_out + inp_start_offset; + for (int j = 0; j < updates_size; j++) + tmp_p_out[j] = rd(tmp_p_out[j], tmp_p_updates[j]); + } + } + + template + inline void typeDispatch(const int type, Args&&... args) + { + switch (type) + { + case CV_8U: + reductionDispatch(std::forward(args)...); + break; + case CV_32S: + reductionDispatch(std::forward(args)...); + break; + case CV_32F: + reductionDispatch(std::forward(args)...); + break; + default: + CV_Error(cv::Error::BadDepth, "Unsupported type."); + }; + } + + template + inline void reductionDispatch(Args&&... args) + { + switch (reduction) + { + case REDUCTION::NONE: + { + auto rd = [](const T& a, const T& b) { return b; }; // a from input data, b from updates + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::ADD: + { + auto rd = [](const T& a, const T& b) { return a + b; }; + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::MUL: + { + auto rd = [](const T& a, const T& b) { return a * b; }; + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::MAX: + { + auto rd = [](const T& a, const T& b) { return std::max(a, b); }; + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::MIN: + { + auto rd = [](const T& a, const T& b) { return std::min(a, b); }; + forward_impl(rd, std::forward(args)...); + break; + } + default: + CV_Error(Error::StsBadArg, "Unsupported reduction."); + }; + } +}; + +Ptr ScatterNDLayer::create(const LayerParams& params) +{ + return makePtr(params); +} + +}} // namespace cv::dnn diff --git a/modules/dnn/src/layers/scatter_layer.cpp b/modules/dnn/src/layers/scatter_layer.cpp new file mode 100644 index 0000000000..084eecb03c --- /dev/null +++ b/modules/dnn/src/layers/scatter_layer.cpp @@ -0,0 +1,208 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +#include "../precomp.hpp" +#include "layers_common.hpp" + +#include // for std::max & std::min + +namespace cv { namespace dnn { + +class ScatterLayerImpl CV_FINAL : public ScatterLayer +{ +public: + enum class REDUCTION + { + NONE = 1, + ADD, + MUL, + MAX, + MIN + } reduction; + + ScatterLayerImpl(const LayerParams& params) + { + setParamsFrom(params); + + axis = params.get("axis", 0); + String reduction_name = toLowerCase(params.get("reduction", "none")); + if (reduction_name == "none") + reduction = REDUCTION::NONE; + else if (reduction_name == "add") + reduction = REDUCTION::ADD; + else if (reduction_name == "mul") + reduction = REDUCTION::MUL; + else if (reduction_name == "max") + reduction = REDUCTION::MAX; + else if (reduction_name == "min") + reduction = REDUCTION::MIN; + else + CV_Error(cv::Error::StsBadArg, "Unkown reduction \"" + reduction_name + "\""); + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE + { + return backendId == DNN_BACKEND_OPENCV; + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE + { + CV_CheckEQ(inputs.size(), 3ull, "Scatter: require three inputs."); + CV_CheckEQ(inputs[0].size(), inputs[1].size(), "Scatter: input data should have the same ndim with indices."); + CV_CheckEQ(inputs[0].size(), inputs[2].size(), "Scatter: input data should have the same ndim with updates."); + for (size_t i = 0; i < inputs[0].size(); i++) + { + CV_CheckGE(inputs[0][i], inputs[1][i], "Scatter: each dim of input data should be greater than (or equal to) indices'."); + CV_CheckEQ(inputs[1][i], inputs[2][i], "Scatter: each dim of indices should be equal to updates'."); + } + outputs.assign(1, inputs[0]); + return false; + } + + void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE + { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + + const Mat& data = inputs[0]; + const Mat& indices = inputs[1]; + const Mat& updates = inputs[2]; + Mat& out = outputs[0]; + + typeDispatch(outputs[0].type(), data, indices, updates, out); + } + + template + void forward_impl(const Functor& rd, const Mat& data, const Mat& indices, const Mat& updates, Mat& out) + { + data.copyTo(out); + + const int ndims = data.dims; + const int* shape = data.size.p; + const size_t* step = data.step.p; + + const int* ind_shape = indices.size.p; + const size_t* ind_step = indices.step.p; + + size_t inp_offset = 0; + size_t ind_offset = 0; + const T* p_index = indices.ptr(); + const T* p_update = updates.ptr(); + T* p_out = out.ptr(); + + size_t total = indices.total(); + + int j, offset_at_idx, index; + size_t t, idx; + for (size_t i = 0; i < total; i++) + { + t = i; + inp_offset = 0; + ind_offset = 0; + int offset_at_axis = 0; + for (j = ndims - 1; j >= 0; j--) + { + idx = t / ind_shape[j]; + offset_at_idx = (int)(t - idx * ind_shape[j]); + ind_offset += offset_at_idx * ind_step[j]; + inp_offset += offset_at_idx * step[j]; + t = idx; + if (j == axis) + { + offset_at_axis = offset_at_idx * step[j]; + } + } + ind_offset /= sizeof(T); + + // get index and overwrite current indices + const T* tmp_p_index = p_index + ind_offset; + index = (int)(*tmp_p_index); + CV_Assert(index < shape[axis] && index > -shape[axis]); + + inp_offset = inp_offset - offset_at_axis + ((index + shape[axis]) % shape[axis]) * step[axis]; + inp_offset /= sizeof(T); + + const T* tmp_p_update = p_update + ind_offset; + T* tmp_p_out = p_out + inp_offset; + *tmp_p_out = rd(*tmp_p_out, *tmp_p_update); + } + } + + template + inline void typeDispatch(const int type, Args&&... args) + { + switch (type) + { + case CV_8U: + reductionDispatch(std::forward(args)...); + break; + case CV_32S: + reductionDispatch(std::forward(args)...); + break; + case CV_32F: + reductionDispatch(std::forward(args)...); + break; + default: + CV_Error(cv::Error::BadDepth, "Unsupported type."); + }; + } + + template + inline void reductionDispatch(Args&&... args) + { + switch (reduction) + { + case REDUCTION::NONE: + { + auto rd = [](const T& a, const T& b) { return b; }; // a from input data, b from updates + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::ADD: + { + auto rd = [](const T& a, const T& b) { return a + b; }; + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::MUL: + { + auto rd = [](const T& a, const T& b) { return a * b; }; + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::MAX: + { + auto rd = [](const T& a, const T& b) { return std::max(a, b); }; + forward_impl(rd, std::forward(args)...); + break; + } + case REDUCTION::MIN: + { + auto rd = [](const T& a, const T& b) { return std::min(a, b); }; + forward_impl(rd, std::forward(args)...); + break; + } + default: + CV_Error(Error::StsBadArg, "Unsupported reduction."); + }; + } + +private: + // Attributes + int axis; +}; + +Ptr ScatterLayer::create(const LayerParams& params) +{ + return makePtr(params); +} + +}} // namespace cv::dnn diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index e1792d1fde..798f439c5c 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -188,6 +188,7 @@ private: void parseElementWise (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseDepthToSpace (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseRange (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); + void parseScatter (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); void parseSimpleLayers (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto); // Domain: com.microsoft @@ -3131,6 +3132,58 @@ void ONNXImporter::parseRange(LayerParams& layerParams, const opencv_onnx::NodeP constBlobsExtraInfo.insert(std::make_pair(node_proto.output(0), TensorInfo(1))); } +void ONNXImporter::parseScatter(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) +{ + CV_CheckEQ(node_proto.input_size(), 3, "Scatter: three inputs are required."); + layerParams.type = "Scatter"; + if (node_proto.op_type() == "ScatterND") + layerParams.type = "ScatterND"; + + size_t consts = 0; + for (size_t i = 0; i < node_proto.input_size(); ++i) + if (layer_id.find(node_proto.input(i)) == layer_id.end()) + ++consts; + + if (consts == node_proto.input_size()) + { + std::vector inputs, output; + for (size_t i = 0; i < node_proto.input_size(); i++) + { + Mat blob = getBlob(node_proto, i); + if (i == 1) // indices + blob.convertTo(blob, CV_32F); + inputs.push_back(blob); + } + runLayer(layerParams, inputs, output); + CV_Assert(output.size() == 1); + addConstant(node_proto.output(0), output[0]); + return; + } + else if (consts > 0) + { + for (size_t i = 0; i < node_proto.input_size(); i++) + { + if (layer_id.find(node_proto.input(i)) == layer_id.end()) + { + Mat blob = getBlob(node_proto, i); + if (i == 1) // indices, from int32/int64 to float32 + blob.convertTo(blob, CV_32F); + + LayerParams constParams; + constParams.name = node_proto.input(i); + constParams.type = "Const"; + constParams.blobs.push_back(blob); + + opencv_onnx::NodeProto proto; + proto.add_output(constParams.name); + addLayer(constParams, proto); + } + } + } + + addLayer(layerParams, node_proto); +} + void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { bool is_all_input_const = true; @@ -3785,6 +3838,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version) dispatch["DetectionOutput"] = &ONNXImporter::parseDetectionOutput; dispatch["CumSum"] = &ONNXImporter::parseCumSum; dispatch["SpaceToDepth"] = dispatch["DepthToSpace"] = &ONNXImporter::parseDepthToSpace; + dispatch["ScatterElements"] = dispatch["Scatter"] = dispatch["ScatterND"] = &ONNXImporter::parseScatter; dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] = dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] = diff --git a/modules/dnn/test/test_onnx_conformance.cpp b/modules/dnn/test/test_onnx_conformance.cpp index e9bc0e4187..fc766c2b81 100644 --- a/modules/dnn/test/test_onnx_conformance.cpp +++ b/modules/dnn/test/test_onnx_conformance.cpp @@ -666,11 +666,15 @@ static const TestCase testConformanceConfig[] = { {"test_scatter_elements_with_axis", 3, 1}, {"test_scatter_elements_with_duplicate_indices", 3, 1}, {"test_scatter_elements_with_negative_indices", 3, 1}, + {"test_scatter_elements_with_reduction_max", 3, 1}, + {"test_scatter_elements_with_reduction_min", 3, 1}, {"test_scatter_elements_without_axis", 3, 1}, {"test_scatter_with_axis", 3, 1}, {"test_scatter_without_axis", 3, 1}, {"test_scatternd", 3, 1}, {"test_scatternd_add", 3, 1}, + {"test_scatternd_max", 3, 1}, + {"test_scatternd_min", 3, 1}, {"test_scatternd_multiply", 3, 1}, {"test_sce_NCd1_mean_weight_negative_ii", 3, 1}, {"test_sce_NCd1_mean_weight_negative_ii_expanded", 3, 1}, diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp index c18ced0c59..4c05f10305 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__cuda_denylist.inl.hpp @@ -82,3 +82,16 @@ "test_sub_uint8", "test_tan", // FP16 only "test_upsample_nearest", +"test_scatter_elements_with_axis", +"test_scatter_elements_with_duplicate_indices", +"test_scatter_elements_with_negative_indices", +"test_scatter_elements_with_reduction_max", +"test_scatter_elements_with_reduction_min", +"test_scatter_elements_without_axis", +"test_scatter_with_axis", +"test_scatter_without_axis", +"test_scatternd", +"test_scatternd_add", +"test_scatternd_max", +"test_scatternd_min", +"test_scatternd_multiply", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp index 72900a8194..4924aaf9da 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__halide_denylist.inl.hpp @@ -95,3 +95,16 @@ "test_sub_uint8", "test_tanh", "test_upsample_nearest", +"test_scatter_elements_with_axis", +"test_scatter_elements_with_duplicate_indices", +"test_scatter_elements_with_negative_indices", +"test_scatter_elements_with_reduction_max", +"test_scatter_elements_with_reduction_min", +"test_scatter_elements_without_axis", +"test_scatter_with_axis", +"test_scatter_without_axis", +"test_scatternd", +"test_scatternd_add", +"test_scatternd_max", +"test_scatternd_min", +"test_scatternd_multiply", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp index cad914d05a..e6a35dfab9 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__openvino.inl.hpp @@ -1588,6 +1588,10 @@ CASE(test_scatter_elements_with_duplicate_indices) // no filter CASE(test_scatter_elements_with_negative_indices) // no filter +CASE(test_scatter_elements_with_reduction_max) + // no filter +CASE(test_scatter_elements_with_reduction_min) + // no filter CASE(test_scatter_elements_without_axis) // no filter CASE(test_scatter_with_axis) @@ -1598,6 +1602,10 @@ CASE(test_scatternd) // no filter CASE(test_scatternd_add) // no filter +CASE(test_scatternd_max) + // no filter +CASE(test_scatternd_min) + // no filter CASE(test_scatternd_multiply) // no filter CASE(test_sce_NCd1_mean_weight_negative_ii) diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp index 101d44cbf0..8156686428 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter__vulkan_denylist.inl.hpp @@ -63,3 +63,16 @@ "test_sub_uint8", "test_transpose_all_permutations_0", "test_upsample_nearest", +"test_scatter_elements_with_axis", +"test_scatter_elements_with_duplicate_indices", +"test_scatter_elements_with_negative_indices", +"test_scatter_elements_with_reduction_max", +"test_scatter_elements_with_reduction_min", +"test_scatter_elements_without_axis", +"test_scatter_with_axis", +"test_scatter_without_axis", +"test_scatternd", +"test_scatternd_add", +"test_scatternd_max", +"test_scatternd_min", +"test_scatternd_multiply", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp index c2425d469f..9b6b2414db 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp16_denylist.inl.hpp @@ -30,4 +30,17 @@ "test_reduce_sum_square_default_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.0183411 vs 0.004 "test_reduce_sum_square_do_not_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02 "test_reduce_sum_square_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02 -"test_reduce_sum_square_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02 \ No newline at end of file +"test_reduce_sum_square_negative_axes_keepdims_random", // Expected: (normL1) <= (l1), actual: 0.010789 vs 0.004, Expected: (normInf) <= (lInf), actual: 0.0290298 vs 0.02 +"test_scatter_elements_with_axis", +"test_scatter_elements_with_duplicate_indices", +"test_scatter_elements_with_negative_indices", +"test_scatter_elements_with_reduction_max", +"test_scatter_elements_with_reduction_min", +"test_scatter_elements_without_axis", +"test_scatter_with_axis", +"test_scatter_without_axis", +"test_scatternd", +"test_scatternd_add", +"test_scatternd_max", +"test_scatternd_min", +"test_scatternd_multiply", diff --git a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp32_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp32_denylist.inl.hpp index 9a7a21f393..7fe58a07fd 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp32_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_filter_opencv_ocl_fp32_denylist.inl.hpp @@ -1,2 +1,15 @@ "test_averagepool_3d_default", "test_maxpool_3d_default", +"test_scatter_elements_with_axis", +"test_scatter_elements_with_duplicate_indices", +"test_scatter_elements_with_negative_indices", +"test_scatter_elements_with_reduction_max", +"test_scatter_elements_with_reduction_min", +"test_scatter_elements_without_axis", +"test_scatter_with_axis", +"test_scatter_without_axis", +"test_scatternd", +"test_scatternd_add", +"test_scatternd_max", +"test_scatternd_min", +"test_scatternd_multiply", diff --git a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp index 1437e5475b..0630833b1f 100644 --- a/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp +++ b/modules/dnn/test/test_onnx_conformance_layer_parser_denylist.inl.hpp @@ -384,15 +384,6 @@ "test_roialign_aligned_true", "test_scan9_sum", "test_scan_sum", -"test_scatter_elements_with_axis", -"test_scatter_elements_with_duplicate_indices", -"test_scatter_elements_with_negative_indices", -"test_scatter_elements_without_axis", -"test_scatter_with_axis", -"test_scatter_without_axis", -"test_scatternd", -"test_scatternd_add", -"test_scatternd_multiply", "test_sce_NCd1_mean_weight_negative_ii", "test_sce_NCd1_mean_weight_negative_ii_expanded", "test_sce_NCd1_mean_weight_negative_ii_log_prob",