Merge pull request #23279 from fengyuentau:add_topk
dnn: add ONNX TopK #23279

Merge with https://github.com/opencv/opencv_extra/pull/1200. Partially fixes #22890 and #20258.

To-do:

- [x] TopK forward impl
- [x] add tests
- [x] support Opset 1 & 10 if possible
- [ ] ~~Support other backends~~ (TopK has two outputs, which is not supported by other backends, such as OpenVINO)

Performance on M1 (time in milliseconds):

| input shape     | axis | dnn   | ort   |
| --------------- | ---- | ----- | ----- |
| (1000, 100)     | 0    | 1.68  | 4.07  |
| (1000, 100) K5  | 0    | 1.13  | 0.12  |
| (1000, 100)     | 1    | 0.96  | 0.77  |
| (100, 100, 100) | 0    | 10.00 | 31.13 |
| (100, 100, 100) | 1    | 7.33  | 9.17  |
| (100, 100, 100) | 2    | 7.52  | 9.48  |

Performance on M2 (time in milliseconds):

| input shape     | axis | dnn  | ort   |
| --------------- | ---- | ---- | ----- |
| (1000, 100)     | 0    | 0.76 | 2.44  |
| (1000, 100) K5  | 0    | 0.68 | 0.07  |
| (1000, 100)     | 1    | 0.41 | 0.50  |
| (100, 100, 100) | 0    | 4.83 | 17.52 |
| (100, 100, 100) | 1    | 3.60 | 5.08  |
| (100, 100, 100) | 2    | 3.73 | 5.10  |

ONNXRuntime performance testing script: https://gist.github.com/fengyuentau/a119f94fd16721ec9974b8c7b0a45d4c

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV.
- [x] The PR is proposed to the proper branch.
- [x] There is a reference to the original bug report and related work.
- [x] There is an accuracy test, a performance test, and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake.
This commit is contained in:
Parent: 7cf075c392
Commit: 347d673a87
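For context, here is a minimal usage sketch of the feature added by this PR: loading an ONNX model that contains a TopK node and reading back both of its outputs. The model file name and the output blob names ("values", "indices") are illustrative assumptions, matching the names used by the accuracy test further down; only the OpenCV backend runs TopK, since the layer produces two outputs.

// Sketch only: "model_with_topk.onnx" and the blob names "values"/"indices"
// are assumptions; they depend on how the model was exported.
#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>
#include <vector>

int main()
{
    cv::dnn::Net net = cv::dnn::readNetFromONNX("model_with_topk.onnx");
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV); // TopK is supported only by the OpenCV backend

    cv::Mat input(1000, 100, CV_32F);   // shape must match the model's input
    cv::randn(input, 0.f, 1.f);
    net.setInput(input);

    std::vector<cv::Mat> outs;
    net.forward(outs, std::vector<cv::String>{"values", "indices"}); // TopK has two outputs

    cv::Mat indices;
    outs[1].convertTo(indices, CV_32S); // indices currently come back as CV_32F
    return 0;
}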
@@ -1198,6 +1198,12 @@ CV__DNN_INLINE_NS_BEGIN
        static Ptr<SpaceToDepthLayer> create(const LayerParams &params);
    };

    class CV_EXPORTS TopKLayer : public Layer
    {
    public:
        static Ptr<TopKLayer> create(const LayerParams& params);
    };

//! @}
//! @}
CV__DNN_INLINE_NS_END
@@ -1043,4 +1043,67 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_Elementwise,
                                              /* withWebnn= */ false,
                                              /* withCann= */ false));

struct Layer_TopK : public TestBaseWithParam<tuple<Backend, Target>> {
    void test_layer(const std::vector<int> &input_shape, const int K, const int axis) {
        int backend_id = get<0>(GetParam());
        int target_id = get<1>(GetParam());

        Mat input_data(input_shape, CV_32F);
        randn(input_data, -1.f, 1.f);

        Net net;
        LayerParams lp;
        lp.type = "TopK";
        lp.name = "testLayer";
        lp.set("k", K);
        lp.set("axis", axis);
        net.addLayerToPrev(lp.name, lp.type, lp);

        // Warmup
        {
            net.setInput(input_data);
            net.setPreferableBackend(backend_id);
            net.setPreferableTarget(target_id);
            net.forward();
        }

        TEST_CYCLE() {
            net.forward();
        }

        SANITY_CHECK_NOTHING();
    }

    std::vector<int> input_shape_2d{1000, 100};
    std::vector<int> input_shape_3d{100, 100, 100};
};

PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0) {
    test_layer(input_shape_2d, input_shape_2d[0] / 2, 0);
}
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0_K5) {
    test_layer(input_shape_2d, 5, 0);
}
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis1) {
    test_layer(input_shape_2d, input_shape_2d[1] / 2, 1);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis0) {
    test_layer(input_shape_3d, input_shape_3d[0] / 2, 0);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis1) {
    test_layer(input_shape_3d, input_shape_3d[1] / 2, 1);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis2) {
    test_layer(input_shape_3d, input_shape_3d[2] / 2, 2);
}
INSTANTIATE_TEST_CASE_P(/**/, Layer_TopK,
                        dnnBackendsAndTargets(/* withInferenceEngine= */ false,
                                              /* withHalide= */ false,
                                              /* withCpuOCV= */ true,
                                              /* withVkCom= */ false,
                                              /* withCUDA= */ false,
                                              /* withNgraph= */ false,
                                              /* withWebnn= */ false,
                                              /* withCann= */ false));

} // namespace
@@ -199,6 +199,7 @@ void initializeLayerFactory()
    CV_DNN_REGISTER_LAYER_CLASS(Scatter, ScatterLayer);
    CV_DNN_REGISTER_LAYER_CLASS(ScatterND, ScatterNDLayer);
    CV_DNN_REGISTER_LAYER_CLASS(Tile, TileLayer);
    CV_DNN_REGISTER_LAYER_CLASS(TopK, TopKLayer);

    CV_DNN_REGISTER_LAYER_CLASS(Quantize, QuantizeLayer);
    CV_DNN_REGISTER_LAYER_CLASS(Dequantize, DequantizeLayer);
modules/dnn/src/layers/topk_layer.cpp (new file, 228 lines)
@@ -0,0 +1,228 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "layers_common.hpp"

#include <opencv2/dnn/shape_utils.hpp>

namespace cv { namespace dnn {

namespace {

template<typename T>
class ComparatorGreater {
public:
    ComparatorGreater(const T* data, size_t step)
        : data_(data), step_(step) {}

    void addOffset(size_t offset) {
        data_ += offset;
    }

    void minusOffset(size_t offset) {
        data_ -= offset;
    }

    bool operator()(const size_t lhs_idx, const size_t rhs_idx) {
        T lhs = *(data_ + lhs_idx * step_),
          rhs = *(data_ + rhs_idx * step_);
        return (lhs > rhs || (lhs == rhs && lhs_idx < rhs_idx));
    }

private:
    const T* data_;
    size_t step_;
};

template<typename T>
class ComparatorLess {
public:
    ComparatorLess(const T* data, size_t step)
        : data_(data), step_(step) {}

    void addOffset(size_t offset) {
        data_ += offset;
    }

    void minusOffset(size_t offset) {
        data_ -= offset;
    }

    bool operator()(const size_t lhs_idx, const size_t rhs_idx) {
        T lhs = *(data_ + lhs_idx * step_),
          rhs = *(data_ + rhs_idx * step_);
        return (lhs < rhs || (lhs == rhs && lhs_idx < rhs_idx));
    }

private:
    const T* data_;
    size_t step_;
};
}

class TopKLayerImpl CV_FINAL : public TopKLayer
{
public:
    TopKLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);

        axis = params.get<int>("axis", -1);
        largest = params.get<int>("largest", 1) == 1;
        sorted = params.get<int>("sorted", 1) == 1;
        CV_CheckTrue(sorted, "TopK: sorted == false is not supported"); // TODO: support sorted

        CV_CheckTrue(params.has("k"), "TopK: parameter k is required but missing");
        K = params.get<int>("k");
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV;
    }

    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                                 const int requiredOutputs,
                                 std::vector<MatShape> &outputs,
                                 std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        const auto &input_shape = inputs.front();
        int input_dims = input_shape.size();

        // Check if axis is valid
        CV_CheckGE(axis, -input_dims, "TopK: axis is out of range");
        CV_CheckLT(axis, input_dims, "TopK: axis is out of range");
        // Normalize axis
        int axis_normalized = normalize_axis(axis, input_shape.size());

        // Check if K is in range (0, input_shape[axis])
        CV_CheckGT(K, 0, "TopK: K needs to be a positive integer");
        CV_CheckLT(K, input_shape[axis_normalized], "TopK: K is out of range");

        // Assign output shape
        auto output_shape = input_shape;
        output_shape[axis_normalized] = K;
        outputs.assign(1, output_shape);
        outputs.assign(2, output_shape); // TODO: support indices of type CV_32S on 5.x

        return false;
    }

    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
        std::vector<Mat> inputs;
        inputs_arr.getMatVector(inputs);

        // Normalize axis
        auto input_shape = shape(inputs.front());
        axis = normalize_axis(axis, input_shape.size());
    }

    template<class Comparator>
    void FindTopK(const Mat &input, Mat &output_value, Mat &output_index) {
        const auto input_shape = shape(input);
        size_t loops = std::accumulate(input_shape.begin(), input_shape.begin() + axis, 1, std::multiplies<int>());
        size_t step = std::accumulate(input_shape.begin() + axis + 1, input_shape.end(), 1, std::multiplies<int>());
        int dim_axis = input_shape[axis];
        if (loops == 1) {
            auto worker = [&](const Range &r) {
                const auto *input_ptr = input.ptr<const float>(); // TODO: support other input type
                auto *output_value_ptr = output_value.ptr<float>();
                auto *output_index_ptr = output_index.ptr<float>(); // TODO: use CV_32S on 5.x

                Comparator cmp(input_ptr, step);

                AutoBuffer<int> buffer_index(dim_axis);
                auto *buffer_index_ptr = buffer_index.data();
                for (int offset = r.start; offset < r.end; offset++) {
                    const auto *input_offset_ptr = input_ptr + offset;
                    cmp.addOffset(offset);

                    std::iota(buffer_index_ptr, buffer_index_ptr + dim_axis, 0);
                    std::stable_sort(buffer_index_ptr, buffer_index_ptr + dim_axis, cmp);

                    auto *output_value_offset_ptr = output_value_ptr + offset;
                    auto *output_index_offset_ptr = output_index_ptr + offset;
                    for (int i = 0; i < K; i++) {
                        int source_index = buffer_index_ptr[i];
                        output_value_offset_ptr[i * step] = *(input_offset_ptr + source_index * step);
                        output_index_offset_ptr[i * step] = source_index;
                    }
                    cmp.minusOffset(offset);
                }
            };
            parallel_for_(Range(0, step), worker);
        } else {
            auto worker = [&](const Range &r) {
                const auto *input_ptr = input.ptr<const float>();
                auto *output_value_ptr = output_value.ptr<float>();
                auto *output_index_ptr = output_index.ptr<float>();

                Comparator cmp(input_ptr, step);

                AutoBuffer<int> buffer_index(dim_axis);
                auto *buffer_index_ptr = buffer_index.data();
                for (int batch_index = r.start; batch_index < r.end; batch_index++) {
                    for (size_t offset = 0; offset < step; offset++) {
                        const auto *input_offset_ptr = input_ptr + batch_index * dim_axis * step + offset;
                        cmp.addOffset(batch_index * dim_axis * step + offset);

                        std::iota(buffer_index_ptr, buffer_index_ptr + dim_axis, 0);
                        std::stable_sort(buffer_index_ptr, buffer_index_ptr + dim_axis, cmp);

                        auto *output_value_offset_ptr = output_value_ptr + batch_index * K * step + offset;
                        auto *output_index_offset_ptr = output_index_ptr + batch_index * K * step + offset;
                        for (int i = 0; i < K; i++) {
                            int source_index = buffer_index_ptr[i];
                            output_value_offset_ptr[i * step] = *(input_offset_ptr + source_index * step);
                            output_index_offset_ptr[i * step] = source_index;
                        }
                        cmp.minusOffset(batch_index * dim_axis * step + offset);
                    }
                }
            };
            parallel_for_(Range(0, loops), worker);
        }
    }

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        if (inputs_arr.depth() == CV_16F)
        {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        const auto &input = inputs.front();
        auto &output_value = outputs.front();
        auto &output_index = outputs.back();

        if (largest) {
            FindTopK<ComparatorGreater<float>>(input, output_value, output_index);
        } else {
            FindTopK<ComparatorLess<float>>(input, output_value, output_index);
        }
    }

private:
    int axis;
    bool largest;
    bool sorted;

    int K; // FIXIT: make it layer input once dynamic shape is supported
};

Ptr<TopKLayer> TopKLayer::create(const LayerParams& params)
{
    return makePtr<TopKLayerImpl>(params);
}

}} // namespace cv::dnn
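The selection strategy in FindTopK above boils down to an index sort: fill a buffer with indices 0..n-1, stable-sort it by the referenced values (ties resolved toward the smaller index), and copy out the first K entries. Below is a minimal standalone sketch of that idea, stripped of the stride and parallel_for_ bookkeeping; it is an illustration, not the layer code itself. std::partial_sort would be a possible alternative when K is much smaller than n.

// Standalone illustration of the index-sort Top-K used above (largest=1, sorted=1).
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

// Indices of the K largest elements, ordered by value; stable_sort keeps
// the smaller index first on ties, like the comparators in the layer.
static std::vector<int> topKIndices(const std::vector<float> &data, int K)
{
    std::vector<int> idx(data.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::stable_sort(idx.begin(), idx.end(),
                     [&](int a, int b) { return data[a] > data[b]; });
    idx.resize(K); // assumes 0 < K <= data.size(), as the layer checks in getMemoryShapes
    return idx;
}

int main()
{
    std::vector<float> v{0.3f, 1.2f, -0.5f, 1.2f, 0.9f};
    for (int i : topKIndices(v, 3))
        std::printf("index %d, value %.2f\n", i, v[i]);
    return 0;
}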
@@ -194,6 +194,7 @@ private:
    void parseScatter (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
    void parseTile (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
    void parseLayerNorm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
    void parseTopK (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
    void parseSimpleLayers (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
    void parseEinsum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);

@@ -3121,6 +3122,21 @@ void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::N
    }
}

void ONNXImporter::parseTopK(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
    // K needs to be constant in case of being input (since opset 10)
    if (node_proto.input_size() == 2) {
        bool K_const = constBlobs.find(node_proto.input(1)) != constBlobs.end();
        CV_CheckTrue(K_const, "OnnxImporter/TopK: K being non-constant is not supported");

        Mat input_K = getBlob(node_proto, 1);
        int K = input_K.at<int>(0);
        layerParams.set("k", K);
    }

    addLayer(layerParams, node_proto);
}

void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
    bool is_all_input_const = true;

@@ -3931,6 +3947,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
    dispatch["Tile"] = &ONNXImporter::parseTile;
    dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm;
    dispatch["GroupNormalization"] = &ONNXImporter::parseInstanceNormalization;
    dispatch["TopK"] = &ONNXImporter::parseTopK;

    dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
        dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =
@@ -3202,6 +3202,37 @@ TEST_P(Test_ONNX_layers, ClipDivSharedConstant) {
    testONNXModels("clip_div_shared_constant");
}

TEST_P(Test_ONNX_layers, TopK) {
    auto test = [&](const std::string &basename, double l1 = 0, double lInf = 0) {
        std::string onnxmodel = _tf("models/" + basename + ".onnx", true);
        Mat input = readTensorFromONNX(_tf("data/input_" + basename + ".pb"));
        Mat output_ref_val = readTensorFromONNX(_tf("data/output_" + basename + "_0.pb")),
            output_ref_ind = readTensorFromONNX(_tf("data/output_" + basename + "_1.pb"));

        checkBackend(&input, &output_ref_val);
        checkBackend(&input, &output_ref_ind);
        Net net = readNetFromONNX(onnxmodel);
        net.setPreferableBackend(backend);
        net.setPreferableTarget(target);

        net.setInput(input);
        std::vector<Mat> outputs;
        net.forward(outputs, std::vector<std::string>{"values", "indices"});

        Mat output_res_val = outputs.front(),
            output_res_ind = outputs.back();
        output_res_ind.convertTo(output_res_ind, CV_32S); // TODO: remove this conversion on 5.x

        normAssert(output_ref_val, output_res_val, (basename + " values").c_str(), l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
        normAssert(output_ref_ind, output_res_ind, (basename + " indices").c_str(), l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
        expectNoFallbacksFromIE(net);
    };

    test("top_k");
    test("top_k_negative_axis");
    test("top_k_smallest");
}

INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());

}} // namespace