Merge pull request #23279 from fengyuentau:add_topk

dnn: add ONNX TopK #23279

Merge with https://github.com/opencv/opencv_extra/pull/1200

Partially fixes #22890 and #20258

To-do:

- [x] TopK forward impl
- [x] add tests
- [x] support Opset 1 & 10 if possible
- [ ] ~Support other backends~ (TopK has two outputs, which is not supported by other backends, such as openvino)


Perf:

M1 (time in millisecond)

| input shape     | axis | dnn  | ort  |
| --------------- | ---- | ---- | ---- |
| (1000, 100)     | 0    | 1.68 | 4.07 |
| (1000, 100) K5  | 0    | 1.13 | 0.12 |
| (1000, 100)     | 1    | 0.96 | 0.77 |
| (100, 100, 100) | 0    | 10.00 | 31.13 |
| (100, 100, 100) | 1    | 7.33 | 9.17 |
| (100, 100, 100) | 2    | 7.52 | 9.48 |

M2 (time in milisecond)

| input shape     | axis | dnn  | ort  |
| --------------- | ---- | ---- | ---- |
| (1000, 100)     | 0    | 0.76 | 2.44 |
| (1000, 100) K5 | 0 | 0.68 | 0.07 |
| (1000, 100)     | 1    | 0.41 | 0.50 |
| (100, 100, 100) | 0    | 4.83 | 17.52|
| (100, 100, 100) | 1    | 3.60 | 5.08 |
| (100, 100, 100) | 2    | 3.73 | 5.10 |

ONNXRuntime performance testing script: https://gist.github.com/fengyuentau/a119f94fd16721ec9974b8c7b0a45d4c

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Yuantao Feng 2024-08-21 22:03:24 +08:00 committed by GitHub
parent 7cf075c392
commit 347d673a87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 346 additions and 0 deletions

View File

@ -1198,6 +1198,12 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<SpaceToDepthLayer> create(const LayerParams &params);
};
class CV_EXPORTS TopKLayer : public Layer
{
public:
static Ptr<TopKLayer> create(const LayerParams& params);
};
//! @}
//! @}
CV__DNN_INLINE_NS_END

View File

@ -1043,4 +1043,67 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_Elementwise,
/* withWebnn= */ false,
/* withCann= */ false));
struct Layer_TopK : public TestBaseWithParam<tuple<Backend, Target>> {
void test_layer(const std::vector<int> &input_shape, const int K, const int axis) {
int backend_id = get<0>(GetParam());
int target_id = get<1>(GetParam());
Mat input_data(input_shape, CV_32F);
randn(input_data, -1.f, 1.f);
Net net;
LayerParams lp;
lp.type = "TopK";
lp.name = "testLayer";
lp.set("k", K);
lp.set("axis", axis);
net.addLayerToPrev(lp.name, lp.type, lp);
// Warmup
{
net.setInput(input_data);
net.setPreferableBackend(backend_id);
net.setPreferableTarget(target_id);
net.forward();
}
TEST_CYCLE() {
net.forward();
}
SANITY_CHECK_NOTHING();
}
std::vector<int> input_shape_2d{1000, 100};
std::vector<int> input_shape_3d{100, 100, 100};
};
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0) {
test_layer(input_shape_2d, input_shape_2d[0] / 2, 0);
}
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis0_K5) {
test_layer(input_shape_2d, 5, 0);
}
PERF_TEST_P_(Layer_TopK, TopK_2D_Axis1) {
test_layer(input_shape_2d, input_shape_2d[1] / 2, 1);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis0) {
test_layer(input_shape_3d, input_shape_3d[0] / 2, 0);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis1) {
test_layer(input_shape_3d, input_shape_3d[1] / 2, 1);
}
PERF_TEST_P_(Layer_TopK, TopK_3D_Axis2) {
test_layer(input_shape_3d, input_shape_3d[2] / 2, 2);
}
INSTANTIATE_TEST_CASE_P(/**/, Layer_TopK,
dnnBackendsAndTargets(/* withInferenceEngine= */ false,
/* withHalide= */ false,
/* withCpuOCV= */ true,
/* withVkCom= */ false,
/* withCUDA= */ false,
/* withNgraph= */ false,
/* withWebnn= */ false,
/* withCann= */ false));
} // namespace

View File

@ -199,6 +199,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Scatter, ScatterLayer);
CV_DNN_REGISTER_LAYER_CLASS(ScatterND, ScatterNDLayer);
CV_DNN_REGISTER_LAYER_CLASS(Tile, TileLayer);
CV_DNN_REGISTER_LAYER_CLASS(TopK, TopKLayer);
CV_DNN_REGISTER_LAYER_CLASS(Quantize, QuantizeLayer);
CV_DNN_REGISTER_LAYER_CLASS(Dequantize, DequantizeLayer);

View File

@ -0,0 +1,228 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#include "layers_common.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv { namespace dnn {
namespace {
template<typename T>
class ComparatorGreater {
public:
ComparatorGreater(const T* data, size_t step)
: data_(data), step_(step) {}
void addOffset(size_t offset) {
data_ += offset;
}
void minusOffset(size_t offset) {
data_ -= offset;
}
bool operator()(const size_t lhs_idx, const size_t rhs_idx) {
T lhs = *(data_ + lhs_idx * step_),
rhs = *(data_ + rhs_idx * step_);
return (lhs > rhs || (lhs == rhs && lhs_idx < rhs_idx));
}
private:
const T* data_;
size_t step_;
};
template<typename T>
class ComparatorLess {
public:
ComparatorLess(const T* data, size_t step)
: data_(data), step_(step) {}
void addOffset(size_t offset) {
data_ += offset;
}
void minusOffset(size_t offset) {
data_ -= offset;
}
bool operator()(const size_t lhs_idx, const size_t rhs_idx) {
T lhs = *(data_ + lhs_idx * step_),
rhs = *(data_ + rhs_idx * step_);
return (lhs < rhs || (lhs == rhs && lhs_idx < rhs_idx));
}
private:
const T* data_;
size_t step_;
};
}
class TopKLayerImpl CV_FINAL : public TopKLayer
{
public:
TopKLayerImpl(const LayerParams& params)
{
setParamsFrom(params);
axis = params.get<int>("axis", -1);
largest = params.get<int>("largest", 1) == 1;
sorted = params.get<int>("sorted", 1) == 1;
CV_CheckTrue(sorted, "TopK: sorted == false is not supported"); // TODO: support sorted
CV_CheckTrue(params.has("k"), "TopK: parameter k is required but missing");
K = params.get<int>("k");
}
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
return backendId == DNN_BACKEND_OPENCV;
}
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE
{
const auto &input_shape = inputs.front();
int input_dims = input_shape.size();
// Check if axis is valid
CV_CheckGE(axis, -input_dims, "TopK: axis is out of range");
CV_CheckLT(axis, input_dims, "TopK: axis is out of range");
// Normalize axis
int axis_normalized = normalize_axis(axis, input_shape.size());
// Check if K is in range (0, input_shape[axis])
CV_CheckGT(K, 0, "TopK: K needs to be a positive integer");
CV_CheckLT(K, input_shape[axis_normalized], "TopK: K is out of range");
// Assign output shape
auto output_shape = input_shape;
output_shape[axis_normalized] = K;
outputs.assign(1, output_shape);
outputs.assign(2, output_shape); // TODO: support indices of type CV_32S on 5.x
return false;
}
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
std::vector<Mat> inputs;
inputs_arr.getMatVector(inputs);
// Normalize axis
auto input_shape = shape(inputs.front());
axis = normalize_axis(axis, input_shape.size());
}
template<class Comparator>
void FindTopK(const Mat &input, Mat &output_value, Mat &output_index) {
const auto input_shape = shape(input);
size_t loops = std::accumulate(input_shape.begin(), input_shape.begin() + axis, 1, std::multiplies<int>());
size_t step = std::accumulate(input_shape.begin() + axis + 1, input_shape.end(), 1, std::multiplies<int>());
int dim_axis = input_shape[axis];
if (loops == 1) {
auto worker = [&](const Range &r) {
const auto *input_ptr = input.ptr<const float>(); // TODO: support other input type
auto *output_value_ptr = output_value.ptr<float>();
auto *output_index_ptr = output_index.ptr<float>(); // TODO: use CV_32S on 5.x
Comparator cmp(input_ptr, step);
AutoBuffer<int> buffer_index(dim_axis);
auto *buffer_index_ptr = buffer_index.data();
for (int offset = r.start; offset < r.end; offset++) {
const auto *input_offset_ptr = input_ptr + offset;
cmp.addOffset(offset);
std::iota(buffer_index_ptr, buffer_index_ptr + dim_axis, 0);
std::stable_sort(buffer_index_ptr, buffer_index_ptr + dim_axis, cmp);
auto *output_value_offset_ptr = output_value_ptr + offset;
auto *output_index_offset_ptr = output_index_ptr + offset;
for (int i = 0; i < K; i++) {
int source_index = buffer_index_ptr[i];
output_value_offset_ptr[i * step] = *(input_offset_ptr + source_index * step);
output_index_offset_ptr[i * step] = source_index;
}
cmp.minusOffset(offset);
}
};
parallel_for_(Range(0, step), worker);
} else {
auto worker = [&](const Range &r) {
const auto *input_ptr = input.ptr<const float>();
auto *output_value_ptr = output_value.ptr<float>();
auto *output_index_ptr = output_index.ptr<float>();
Comparator cmp(input_ptr, step);
AutoBuffer<int> buffer_index(dim_axis);
auto *buffer_index_ptr = buffer_index.data();
for (int batch_index = r.start; batch_index < r.end; batch_index++) {
for (size_t offset = 0; offset < step; offset++) {
const auto *input_offset_ptr = input_ptr + batch_index * dim_axis * step + offset;
cmp.addOffset(batch_index * dim_axis * step + offset);
std::iota(buffer_index_ptr, buffer_index_ptr + dim_axis, 0);
std::stable_sort(buffer_index_ptr, buffer_index_ptr + dim_axis, cmp);
auto *output_value_offset_ptr = output_value_ptr + batch_index * K * step + offset;
auto *output_index_offset_ptr = output_index_ptr + batch_index * K * step + offset;
for (int i = 0; i < K; i++) {
int source_index = buffer_index_ptr[i];
output_value_offset_ptr[i * step] = *(input_offset_ptr + source_index * step);
output_index_offset_ptr[i * step] = source_index;
}
cmp.minusOffset(batch_index * dim_axis * step + offset);
}
}
};
parallel_for_(Range(0, loops), worker);
}
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
if (inputs_arr.depth() == CV_16F)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
const auto &input = inputs.front();
auto &output_value = outputs.front();
auto &output_index = outputs.back();
if (largest) {
FindTopK<ComparatorGreater<float>>(input, output_value, output_index);
} else {
FindTopK<ComparatorLess<float>>(input, output_value, output_index);
}
}
private:
int axis;
bool largest;
bool sorted;
int K; // FIXIT: make it layer input once dynamic shape is supported
};
Ptr<TopKLayer> TopKLayer::create(const LayerParams& params)
{
return makePtr<TopKLayerImpl>(params);
}
}} // namespace cv::dnn

View File

@ -194,6 +194,7 @@ private:
void parseScatter (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseTile (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseLayerNorm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseTopK (LayerParams& LayerParams, const opencv_onnx::NodeProto& node_proto);
void parseSimpleLayers (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseEinsum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@ -3121,6 +3122,21 @@ void ONNXImporter::parseLayerNorm(LayerParams& layerParams, const opencv_onnx::N
}
}
void ONNXImporter::parseTopK(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
// K needs to be constant in case of being input (since opset 10)
if (node_proto.input_size() == 2) {
bool K_const = constBlobs.find(node_proto.input(1)) != constBlobs.end();
CV_CheckTrue(K_const, "OnnxImporter/TopK: K being non-constant is not supported");
Mat input_K = getBlob(node_proto, 1);
int K = input_K.at<int>(0);
layerParams.set("k", K);
}
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseSimpleLayers(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
bool is_all_input_const = true;
@ -3931,6 +3947,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
dispatch["Tile"] = &ONNXImporter::parseTile;
dispatch["LayerNormalization"] = &ONNXImporter::parseLayerNorm;
dispatch["GroupNormalization"] = &ONNXImporter::parseInstanceNormalization;
dispatch["TopK"] = &ONNXImporter::parseTopK;
dispatch["Equal"] = dispatch["Greater"] = dispatch["Less"] = dispatch["Pow"] = dispatch["Add"] =
dispatch["Sub"] = dispatch["Mul"] = dispatch["Div"] = dispatch["GreaterOrEqual"] =

View File

@ -3202,6 +3202,37 @@ TEST_P(Test_ONNX_layers, ClipDivSharedConstant) {
testONNXModels("clip_div_shared_constant");
}
TEST_P(Test_ONNX_layers, TopK) {
auto test = [&](const std::string &basename, double l1 = 0, double lInf = 0) {
std::string onnxmodel = _tf("models/" + basename + ".onnx", true);
Mat input = readTensorFromONNX(_tf("data/input_" + basename + ".pb"));
Mat output_ref_val = readTensorFromONNX(_tf("data/output_" + basename + "_0.pb")),
output_ref_ind = readTensorFromONNX(_tf("data/output_" + basename + "_1.pb"));
checkBackend(&input, &output_ref_val);
checkBackend(&input, &output_ref_ind);
Net net = readNetFromONNX(onnxmodel);
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
net.setInput(input);
std::vector<Mat> outputs;
net.forward(outputs, std::vector<std::string>{"values", "indices"});
Mat output_res_val = outputs.front(),
output_res_ind = outputs.back();
output_res_ind.convertTo(output_res_ind, CV_32S); // TODO: remove this conversion on 5.x
normAssert(output_ref_val, output_res_val, (basename + " values").c_str(), l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
normAssert(output_ref_ind, output_res_ind, (basename + " indices").c_str(), l1 ? l1 : default_l1, lInf ? lInf : default_lInf);
expectNoFallbacksFromIE(net);
};
test("top_k");
test("top_k_negative_axis");
test("top_k_smallest");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
}} // namespace