diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index 2abce0c87b..d9d1833780 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -1198,6 +1198,12 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr<GroupNormLayer> create(const LayerParams &params);
     };

+    class CV_EXPORTS TopKLayer : public Layer
+    {
+    public:
+        static Ptr<TopKLayer> create(const LayerParams& params);
+    };
+
 //! @}
 //! @}
 CV__DNN_INLINE_NS_END
diff --git a/modules/dnn/perf/perf_layer.cpp b/modules/dnn/perf/perf_layer.cpp
index 98adc56ffb..07d1349b2d 100644
--- a/modules/dnn/perf/perf_layer.cpp
+++ b/modules/dnn/perf/perf_layer.cpp
@@ -1043,4 +1043,67 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_Elementwise,
                                               /* withWebnn= */ false,
                                               /* withCann= */ false));

+struct Layer_TopK : public TestBaseWithParam<tuple<Backend, Target>> {
+    void test_layer(const std::vector<int> &input_shape, const int K, const int axis) {
+        int backend_id = get<0>(GetParam());
+        int target_id = get<1>(GetParam());
+
+        Mat input_data(input_shape, CV_32F);
+        randn(input_data, -1.f, 1.f);
+
+        Net net;
+        LayerParams lp;
+        lp.type = "TopK";
+        lp.name = "testLayer";
+        lp.set("k", K);
+        lp.set("axis", axis);
+        net.addLayerToPrev(lp.name, lp.type, lp);
+
+        // Warmup
+        {
+            net.setInput(input_data);
+            net.setPreferableBackend(backend_id);
+            net.setPreferableTarget(target_id);
+            net.forward();
+        }
+
+        TEST_CYCLE() {
+            net.forward();
+        }
+
+        SANITY_CHECK_NOTHING();
+    }
+
+    std::vector<int> input_shape_2d{1000, 100};
+    std::vector<int> input_shape_3d{100, 100, 100};
+};
+
+PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0) {
+    test_layer(input_shape_2d, input_shape_2d[0] / 2, 0);
+}
+PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0_K5) {
+    test_layer(input_shape_2d, 5, 0);
+}
+PERF_TEST_P_(Layer_TopK, TopK_2D_Axis1) {
+    test_layer(input_shape_2d, input_shape_2d[1] / 2, 1);
+}
+PERF_TEST_P_(Layer_TopK, TopK_3D_Axis0) {
+    test_layer(input_shape_3d, input_shape_3d[0] / 2, 0);
+}
+PERF_TEST_P_(Layer_TopK, TopK_3D_Axis1) {
+    test_layer(input_shape_3d, input_shape_3d[1] / 2, 1);
+}
+PERF_TEST_P_(Layer_TopK, TopK_3D_Axis2) {
+    test_layer(input_shape_3d, input_shape_3d[2] / 2, 2);
+}
+INSTANTIATE_TEST_CASE_P(/**/, Layer_TopK,
+                        dnnBackendsAndTargets(/* withInferenceEngine= */ false,
+                                              /* withHalide= */ false,
+                                              /* withCpuOCV= */ true,
+                                              /* withVkCom= */ false,
+                                              /* withCUDA= */ false,
+                                              /* withNgraph= */ false,
+                                              /* withWebnn= */ false,
+                                              /* withCann= */ false));
+
 } // namespace
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index ce1eb77649..61db2c1ba6 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -199,6 +199,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Scatter,        ScatterLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ScatterND,      ScatterNDLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Tile,           TileLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(TopK,           TopKLayer);

     CV_DNN_REGISTER_LAYER_CLASS(Quantize,       QuantizeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Dequantize,     DequantizeLayer);
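Aside (not part of the patch): with the registration above in place, a "TopK" layer can be built straight from LayerParams, the same way the perf fixture does it. The sketch below is illustrative only (the layer name "topk_demo" and the shapes are made up); it pushes a small 2-D blob through a one-layer net and reads back both outputs of the layer, something the perf test itself does not do:

    #include <opencv2/dnn.hpp>
    #include <iostream>
    #include <vector>

    int main()
    {
        using namespace cv;
        using namespace cv::dnn;

        // Small 2-D input blob with random values.
        Mat input(std::vector<int>{4, 5}, CV_32F);
        randn(input, 0.f, 1.f);

        // One-layer net around the newly registered "TopK" type.
        LayerParams lp;
        lp.type = "TopK";
        lp.name = "topk_demo";  // illustrative name
        lp.set("k", 3);         // must satisfy 0 < k < input_shape[axis]
        lp.set("axis", 1);      // top-3 along the last axis

        Net net;
        net.addLayerToPrev(lp.name, lp.type, lp);
        net.setInput(input);

        // The layer has two outputs: values first, then indices
        // (indices come back as CV_32F in this 4.x implementation).
        std::vector<Mat> outs;
        net.forward(outs, lp.name);

        std::cout << "values:\n" << outs[0] << "\nindices:\n" << outs[1] << std::endl;
        return 0;
    }

Converting the second blob with convertTo(..., CV_32S) yields integer indices, which is exactly what the accuracy test further down does before comparing against the reference.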
+ +#include "../precomp.hpp" +#include "layers_common.hpp" + +#include + +namespace cv { namespace dnn { + +namespace { + +template +class ComparatorGreater { +public: + ComparatorGreater(const T* data, size_t step) + : data_(data), step_(step) {} + + void addOffset(size_t offset) { + data_ += offset; + } + + void minusOffset(size_t offset) { + data_ -= offset; + } + + bool operator()(const size_t lhs_idx, const size_t rhs_idx) { + T lhs = *(data_ + lhs_idx * step_), + rhs = *(data_ + rhs_idx * step_); + return (lhs > rhs || (lhs == rhs && lhs_idx < rhs_idx)); + } + +private: + const T* data_; + size_t step_; +}; + +template +class ComparatorLess { +public: + ComparatorLess(const T* data, size_t step) + : data_(data), step_(step) {} + + void addOffset(size_t offset) { + data_ += offset; + } + + void minusOffset(size_t offset) { + data_ -= offset; + } + + bool operator()(const size_t lhs_idx, const size_t rhs_idx) { + T lhs = *(data_ + lhs_idx * step_), + rhs = *(data_ + rhs_idx * step_); + return (lhs < rhs || (lhs == rhs && lhs_idx < rhs_idx)); + } + +private: + const T* data_; + size_t step_; +}; +} + +class TopKLayerImpl CV_FINAL : public TopKLayer +{ +public: + TopKLayerImpl(const LayerParams& params) + { + setParamsFrom(params); + + axis = params.get("axis", -1); + largest = params.get("largest", 1) == 1; + sorted = params.get("sorted", 1) == 1; + CV_CheckTrue(sorted, "TopK: sorted == false is not supported"); // TODO: support sorted + + CV_CheckTrue(params.has("k"), "TopK: parameter k is required but missing"); + K = params.get("k"); + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE + { + return backendId == DNN_BACKEND_OPENCV; + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE + { + const auto &input_shape = inputs.front(); + int input_dims = input_shape.size(); + + // Check if axis is valid + CV_CheckGE(axis, -input_dims, "TopK: axis is out of range"); + CV_CheckLT(axis, input_dims, "TopK: axis is out of range"); + // Normalize axis + int axis_normalized = normalize_axis(axis, input_shape.size()); + + // Check if K is in range (0, input_shape[axis]) + CV_CheckGT(K, 0, "TopK: K needs to be a positive integer"); + CV_CheckLT(K, input_shape[axis_normalized], "TopK: K is out of range"); + + // Assign output shape + auto output_shape = input_shape; + output_shape[axis_normalized] = K; + outputs.assign(1, output_shape); + outputs.assign(2, output_shape); // TODO: support indices of type CV_32S on 5.x + + return false; + } + + virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE { + std::vector inputs; + inputs_arr.getMatVector(inputs); + + // Normalize axis + auto input_shape = shape(inputs.front()); + axis = normalize_axis(axis, input_shape.size()); + } + + template + void FindTopK(const Mat &input, Mat &output_value, Mat &output_index) { + const auto input_shape = shape(input); + size_t loops = std::accumulate(input_shape.begin(), input_shape.begin() + axis, 1, std::multiplies()); + size_t step = std::accumulate(input_shape.begin() + axis + 1, input_shape.end(), 1, std::multiplies()); + int dim_axis = input_shape[axis]; + if (loops == 1) { + auto worker = [&](const Range &r) { + const auto *input_ptr = input.ptr(); // TODO: support other input type + auto *output_value_ptr = output_value.ptr(); + auto *output_index_ptr = output_index.ptr(); // TODO: use CV_32S on 5.x + + Comparator cmp(input_ptr, step); + + AutoBuffer 
+                AutoBuffer<int> buffer_index(dim_axis);
+                auto *buffer_index_ptr = buffer_index.data();
+                for (int offset = r.start; offset < r.end; offset++) {
+                    const auto *input_offset_ptr = input_ptr + offset;
+                    cmp.addOffset(offset);
+
+                    std::iota(buffer_index_ptr, buffer_index_ptr + dim_axis, 0);
+                    std::stable_sort(buffer_index_ptr, buffer_index_ptr + dim_axis, cmp);
+
+                    auto *output_value_offset_ptr = output_value_ptr + offset;
+                    auto *output_index_offset_ptr = output_index_ptr + offset;
+                    for (int i = 0; i < K; i++) {
+                        int source_index = buffer_index_ptr[i];
+                        output_value_offset_ptr[i * step] = *(input_offset_ptr + source_index * step);
+                        output_index_offset_ptr[i * step] = source_index;
+                    }
+                    cmp.minusOffset(offset);
+                }
+            };
+            parallel_for_(Range(0, step), worker);
+        } else {
+            auto worker = [&](const Range &r) {
+                const auto *input_ptr = input.ptr<float>();
+                auto *output_value_ptr = output_value.ptr<float>();
+                auto *output_index_ptr = output_index.ptr<float>();
+
+                Comparator cmp(input_ptr, step);
+
+                AutoBuffer<int> buffer_index(dim_axis);
+                auto *buffer_index_ptr = buffer_index.data();
+                for (int batch_index = r.start; batch_index < r.end; batch_index++) {
+                    for (size_t offset = 0; offset < step; offset++) {
+                        const auto *input_offset_ptr = input_ptr + batch_index * dim_axis * step + offset;
+                        cmp.addOffset(batch_index * dim_axis * step + offset);
+
+                        std::iota(buffer_index_ptr, buffer_index_ptr + dim_axis, 0);
+                        std::stable_sort(buffer_index_ptr, buffer_index_ptr + dim_axis, cmp);
+
+                        auto *output_value_offset_ptr = output_value_ptr + batch_index * K * step + offset;
+                        auto *output_index_offset_ptr = output_index_ptr + batch_index * K * step + offset;
+                        for (int i = 0; i < K; i++) {
+                            int source_index = buffer_index_ptr[i];
+                            output_value_offset_ptr[i * step] = *(input_offset_ptr + source_index * step);
+                            output_index_offset_ptr[i * step] = source_index;
+                        }
+                        cmp.minusOffset(batch_index * dim_axis * step + offset);
+                    }
+                }
+            };
+            parallel_for_(Range(0, loops), worker);
+        }
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        if (inputs_arr.depth() == CV_16F)
+        {
+            forward_fallback(inputs_arr, outputs_arr, internals_arr);
+            return;
+        }
+
+        std::vector<Mat> inputs, outputs;
+        inputs_arr.getMatVector(inputs);
+        outputs_arr.getMatVector(outputs);
+
+        const auto &input = inputs.front();
+        auto &output_value = outputs.front();
+        auto &output_index = outputs.back();
+
+        if (largest) {
+            FindTopK<ComparatorGreater<float>>(input, output_value, output_index);
+        } else {
+            FindTopK<ComparatorLess<float>>(input, output_value, output_index);
+        }
+    }
+
+private:
+    int axis;
+    bool largest;
+    bool sorted;
+
+    int K; // FIXIT: make it layer input once dynamic shape is supported
+};
+
+Ptr<TopKLayer> TopKLayer::create(const LayerParams& params)
+{
+    return makePtr<TopKLayerImpl>(params);
+}
+
+}} // namespace cv::dnn
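What FindTopK does, stripped of the Mat and striding plumbing: it argsorts positions along the reduced axis (std::iota seeds the positions, std::stable_sort orders them with one of the two comparators above) and then gathers the first K positions. A minimal sketch of that idea for a contiguous 1-D array follows; the helper name topKIndices is ours, not part of the patch:

    #include <algorithm>
    #include <numeric>
    #include <vector>
    #include <cstddef>

    // Argsort-based top-k on a contiguous 1-D array: order positions by value
    // (descending when 'largest' is true), keep the first k. Ties resolve to the
    // lower index, matching the comparators in topk_layer.cpp.
    static std::vector<size_t> topKIndices(const std::vector<float> &data, size_t k, bool largest = true)
    {
        std::vector<size_t> idx(data.size());
        std::iota(idx.begin(), idx.end(), size_t{0}); // 0, 1, 2, ...
        std::stable_sort(idx.begin(), idx.end(), [&](size_t a, size_t b) {
            return largest ? (data[a] > data[b] || (data[a] == data[b] && a < b))
                           : (data[a] < data[b] || (data[a] == data[b] && a < b));
        });
        idx.resize(std::min(k, idx.size())); // keep the k best positions
        return idx;
    }

The layer generalizes this by walking the axis with stride step (the product of the dimensions after the axis) and by splitting the independent slices across threads with parallel_for_. Since only K entries are kept, std::partial_sort over the index buffer could be a drop-in refinement for small K; the comparators' explicit index tie-break keeps the result deterministic either way.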
diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp
index e91e2605c5..5b53442850 100644
--- a/modules/dnn/src/onnx/onnx_importer.cpp
+++ b/modules/dnn/src/onnx/onnx_importer.cpp
@@ -194,6 +194,7 @@ private:
     void parseScatter            (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseTile               (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseLayerNorm          (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
+    void parseTopK               (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseSimpleLayers       (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
     void parseEinsum             (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);

@@ -3121,6 +3122,21 @@ void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::N
     }
 }

+void ONNXImporter::parseTopK(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
+{
+    // K needs to be constant in case of being input (since opset 10)
+    if (node_proto.input_size() == 2) {
+        bool K_const = constBlobs.find(node_proto.input(1)) != constBlobs.end();
+        CV_CheckTrue(K_const, "OnnxImporter/TopK: K being non-constant is not supported");
+
+        Mat input_K = getBlob(node_proto, 1);
+        int K = input_K.at<int>(0);
+        layerParams.set("k", K);
+    }
+
+    addLayer(layerParams, node_proto);
+}
+
 void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
 {
     bool is_all_input_const = true;
@@ -3931,6 +3947,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
     dispatch["Tile"] = &ONNXImporter::parseTile;
     dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm;
     dispatch["GroupNormalization"] = &ONNXImporter::parseInstanceNormalization;
+    dispatch["TopK"] = &ONNXImporter::parseTopK;

     dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
             dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index f8187e43fb..8cfc469a7e 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -3202,6 +3202,37 @@ TEST_P(Test_ONNX_layers, ClipDivSharedConstant) {
     testONNXModels("clip_div_shared_constant");
 }

+TEST_P(Test_ONNX_layers, TopK) {
+    auto test = [&](const std::string &basename, double l1 = 0, double lInf = 0) {
+        std::string onnxmodel = _tf("models/" + basename + ".onnx", true);
+        Mat input = readTensorFromONNX(_tf("data/input_" + basename + ".pb"));
+        Mat output_ref_val = readTensorFromONNX(_tf("data/output_" + basename + "_0.pb")),
+            output_ref_ind = readTensorFromONNX(_tf("data/output_" + basename + "_1.pb"));
+
+        checkBackend(&input, &output_ref_val);
+        checkBackend(&input, &output_ref_ind);
+        Net net = readNetFromONNX(onnxmodel);
+        net.setPreferableBackend(backend);
+        net.setPreferableTarget(target);
+
+        net.setInput(input);
+        std::vector<Mat> outputs;
+        net.forward(outputs, std::vector<std::string>{"values", "indices"});
+
+        Mat output_res_val = outputs.front(),
+            output_res_ind = outputs.back();
+        output_res_ind.convertTo(output_res_ind, CV_32S); // TODO: remove this conversion on 5.x
+
+        normAssert(output_ref_val, output_res_val, (basename + " values").c_str(), l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
+        normAssert(output_ref_ind, output_res_ind, (basename + " indices").c_str(), l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
+        expectNoFallbacksFromIE(net);
+    };
+
+    test("top_k");
+    test("top_k_negative_axis");
+    test("top_k_smallest");
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());

 }} // namespace
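On the ONNX side, the importer change above is deliberately small: the TopK node's axis/largest/sorted attributes pass straight through to the layer, and since opset 10 the K value arrives as a second input, which parseTopK requires to be a constant initializer and bakes into the layer as the "k" parameter. A usage sketch in the spirit of the test above (the model path is illustrative, and "values"/"indices" are the output names used by the test models, so they depend on the model at hand):

    #include <opencv2/dnn.hpp>
    #include <vector>

    int main()
    {
        using namespace cv;
        using namespace cv::dnn;

        // Illustrative path: any model whose TopK node carries a constant K should import.
        Net net = readNetFromONNX("top_k.onnx");

        // The blob shape must match the model's input; 3x4x5 is only an example.
        Mat input(std::vector<int>{3, 4, 5}, CV_32F);
        randn(input, 0.f, 1.f);
        net.setInput(input);

        // TopK produces two outputs; fetch both by their ONNX output names.
        std::vector<Mat> outs;
        net.forward(outs, std::vector<String>{"values", "indices"});

        Mat values = outs.front();
        Mat indices = outs.back();
        indices.convertTo(indices, CV_32S); // indices arrive as CV_32F on 4.x (see the layer's TODO)
        return 0;
    }

If K is produced dynamically inside the graph, the CV_CheckTrue in parseTopK rejects the model, which matches the FIXIT note in the layer about dynamic shapes.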