mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 01:13:28 +08:00
Merge pull request #24092 from Aser-Abdelfatah:GSoC_Support_GatherElements_ONNX
GSoC Add ONNX Support for GatherElements #24092 Merge with: https://github.com/opencv/opencv_extra/pull/1082 Adds support to the ONNX operator GatherElements [operator docs](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements) Added tests to opencv_extra at pull request https://github.com/opencv/opencv_extra/pull/1082 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
014e8485b5
commit
240b245105
@ -343,6 +343,22 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
static Ptr<GatherLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
/** @brief GatherElements layer
|
||||
* GatherElements takes two inputs data and indices of the same rank r >= 1 and an optional attribute axis and works such that:
|
||||
* output[i][j][k] = data[index[i][j][k]][j][k] if axis = 0 and r = 3
|
||||
* output[i][j][k] = data[i][index[i][j][k]][k] if axis = 1 and r = 3
|
||||
* output[i][j][k] = data[i][j][index[i][j][k]] if axis = 2 and r = 3
|
||||
*
|
||||
* Gather, on the other hand, takes a data tensor of rank r >= 1, and indices tensor of rank q, and works such that:
|
||||
* it gathers the enteries along axis dimension of the input data indexed by indices and concatenates them in an output tensor of rank q + (r - 1)
|
||||
* e.g. If axis = 0, let k = indices[i_{0}, ..., i_{q-1}] then output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[k , j_{0}, ..., j_{r-2}]:
|
||||
**/
|
||||
class CV_EXPORTS GatherElementsLayer : public Layer
|
||||
{
|
||||
public:
|
||||
static Ptr<GatherElementsLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
class CV_EXPORTS PoolingLayer : public Layer
|
||||
{
|
||||
public:
|
||||
|
@ -633,6 +633,56 @@ PERF_TEST_P_(Layer_LayerNormExpanded, DISABLED_LayerNormExpanded)
|
||||
test_layer({N, H ,W});
|
||||
}
|
||||
|
||||
struct Layer_GatherElements : public TestBaseWithParam<tuple<Backend, Target> >
|
||||
{
|
||||
void test_layer(const std::vector<int>& data_shape, const std::vector<int>& indices_shape, int axis = 0)
|
||||
{
|
||||
int backendId = get<0>(GetParam());
|
||||
int targetId = get<1>(GetParam());
|
||||
|
||||
Mat data(data_shape, CV_32FC1);
|
||||
Mat indices(indices_shape, CV_32FC1);
|
||||
|
||||
randu(data, 0.f, 1.f);
|
||||
randu(indices, 0, data_shape[axis]);
|
||||
|
||||
Net net;
|
||||
LayerParams lp;
|
||||
lp.type = "GatherElements";
|
||||
lp.name = "testLayer";
|
||||
lp.set("axis", axis);
|
||||
int id = net.addLayerToPrev(lp.name, lp.type, lp);
|
||||
net.connect(0, 0, id, 0);
|
||||
net.connect(0, 1, id, 1);
|
||||
|
||||
// warmup
|
||||
{
|
||||
std::vector<String> inpNames(3);
|
||||
inpNames[0] = "data";
|
||||
inpNames[1] = "indices";
|
||||
net.setInputsNames(inpNames);
|
||||
net.setInput(data, inpNames[0]);
|
||||
net.setInput(indices, inpNames[1]);
|
||||
|
||||
net.setPreferableBackend(backendId);
|
||||
net.setPreferableTarget(targetId);
|
||||
Mat out = net.forward();
|
||||
}
|
||||
|
||||
TEST_CYCLE()
|
||||
{
|
||||
Mat res = net.forward();
|
||||
}
|
||||
|
||||
SANITY_CHECK_NOTHING();
|
||||
}
|
||||
};
|
||||
|
||||
PERF_TEST_P_(Layer_GatherElements, GatherElements)
|
||||
{
|
||||
test_layer({2700, 1, 2914}, {2700, 1, 81}, 2);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false));
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
|
||||
#ifdef HAVE_CUDA
|
||||
@ -642,6 +692,7 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_Scatter, testing::Values(std::make_tuple(DNN
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Layer_ScatterND, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNormExpanded, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
|
||||
INSTANTIATE_TEST_CASE_P(/**/, Layer_GatherElements, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
|
||||
|
||||
|
||||
typedef TestBaseWithParam<tuple<Vec4i, int, bool, tuple<Backend, Target> > > Layer_FullyConnected;
|
||||
|
@ -157,6 +157,7 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Arg, ArgLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Reciprocal, ReciprocalLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Gather, GatherLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(GatherElements, GatherElementsLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(LayerNormalization, LayerNormLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer);
|
||||
|
||||
|
154
modules/dnn/src/layers/gather_elements_layer.cpp
Normal file
154
modules/dnn/src/layers/gather_elements_layer.cpp
Normal file
@ -0,0 +1,154 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
static inline int calculateOffset(int outer_dim, const MatShape &shape_indices, int axis_skip, const MatStep &step_data) {
|
||||
int offset = 0;
|
||||
for (int axis = static_cast<int>(shape_indices.size()) - 2; axis >= 0; axis--) {
|
||||
int dim = shape_indices[axis];
|
||||
if (axis != axis_skip) {
|
||||
offset += (outer_dim % dim) * step_data[axis];
|
||||
}
|
||||
outer_dim /= dim;
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
class GatherElementsLayerImpl CV_FINAL : public GatherElementsLayer
|
||||
{
|
||||
public:
|
||||
GatherElementsLayerImpl(const LayerParams& params)
|
||||
{
|
||||
setParamsFrom(params);
|
||||
axis = params.get<int>("axis", 0);
|
||||
}
|
||||
|
||||
virtual bool supportBackend(int backendId) CV_OVERRIDE
|
||||
{
|
||||
return backendId == DNN_BACKEND_OPENCV;
|
||||
}
|
||||
|
||||
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_CheckEQ(inputs.size(), 2ull, "GatherElements: requires two inputs");
|
||||
|
||||
const auto &data = inputs[0];
|
||||
const auto &indices = inputs[1];
|
||||
CV_CheckEQ(data.size(), indices.size(), "GatherElements: data and indices should have the same dimension");
|
||||
|
||||
int normalized_axis = normalize_axis(axis, static_cast<int>(data.size()));
|
||||
CV_CheckGE(normalized_axis, 0, "GatherElements: axis out of range");
|
||||
CV_CheckLT(normalized_axis, static_cast<int>(data.size()), "GatherElements: axis out of range");
|
||||
for (size_t i = 0; i < data.size(); i++) {
|
||||
if (i != normalized_axis) {
|
||||
CV_CheckEQ(data[i], indices[i], "GatherElements: shape mismatched");
|
||||
}
|
||||
}
|
||||
|
||||
outputs.assign(1, inputs[1]); // shape of output is same as indices
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
|
||||
std::vector<Mat> inputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
|
||||
const auto &data = inputs[0];
|
||||
axis = normalize_axis(axis, data.dims);
|
||||
}
|
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||
|
||||
std::vector<Mat> inputs, outputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
|
||||
const Mat& data = inputs[0];
|
||||
const Mat& indices = inputs[1];
|
||||
Mat& out = outputs[0];
|
||||
|
||||
typeDispatch(outputs[0].type(), data, indices, out);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void forward_impl(const Mat& data_, const Mat& indices_, Mat& out_)
|
||||
{
|
||||
const auto *ptr_data = data_.ptr<const T>();
|
||||
const auto *ptr_indices = indices_.ptr<const T>();
|
||||
auto *ptr_out = out_.ptr<T>();
|
||||
|
||||
const auto shape_data = shape(data_);
|
||||
const auto &step_data = data_.step;
|
||||
const auto shape_indices = shape(indices_);
|
||||
|
||||
int inner_most_dim = shape_indices.back();
|
||||
int axis_dim = shape_data[axis];
|
||||
size_t axis_step = static_cast<size_t>(step_data[axis] / sizeof(T));
|
||||
|
||||
bool innermost_axis = axis == static_cast<int>(shape_data.size() - 1);
|
||||
|
||||
auto fn = [&](const Range &r) {
|
||||
for (int i = r.start; i < r.end; i++) {
|
||||
auto *data = ptr_data + static_cast<size_t>(calculateOffset(i, shape_indices, axis, step_data) / sizeof(T));
|
||||
auto *indices = ptr_indices + i * inner_most_dim;
|
||||
auto *out = ptr_out + i * inner_most_dim;
|
||||
|
||||
if (innermost_axis) {
|
||||
for (int j = 0; j < inner_most_dim; j++) {
|
||||
int index = static_cast<int>((indices[j] + axis_dim)) % axis_dim; // TODO: Check out-of-range index
|
||||
out[j] = data[index];
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; j < inner_most_dim; j++) {
|
||||
int index = static_cast<int>(indices[j] + axis_dim) % axis_dim; // TODO: Check out-of-range index
|
||||
out[j] = data[index * axis_step + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
int outer_dims = total(shape_indices, 0, shape_indices.size() - 1);
|
||||
double nstripes = static_cast<size_t>(outer_dims * inner_most_dim * (1 / 1024.0));
|
||||
parallel_for_(Range(0, outer_dims), fn, nstripes);
|
||||
}
|
||||
|
||||
template<typename... Args>
|
||||
inline void typeDispatch(const int type, Args&&... args)
|
||||
{
|
||||
switch (type)
|
||||
{
|
||||
case CV_8U:
|
||||
forward_impl<uint8_t>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_32S:
|
||||
forward_impl<int32_t>(std::forward<Args>(args)...);
|
||||
break;
|
||||
case CV_32F:
|
||||
forward_impl<float>(std::forward<Args>(args)...);
|
||||
break;
|
||||
default:
|
||||
CV_Error(cv::Error::BadDepth, "DNN/GatherElements: Unsupported type.");
|
||||
};
|
||||
}
|
||||
|
||||
private:
|
||||
int axis;
|
||||
};
|
||||
|
||||
Ptr<GatherElementsLayer> GatherElementsLayer::create(const LayerParams& params)
|
||||
{
|
||||
return makePtr<GatherElementsLayerImpl>(params);
|
||||
}
|
||||
|
||||
}} // namespace cv::dnn
|
@ -179,6 +179,7 @@ private:
|
||||
void parseCast (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseConstantFill (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseGather (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseGatherElements (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseResize (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseUpsample (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
@ -2553,6 +2554,53 @@ void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::Node
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseGatherElements(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
CV_CheckEQ(node_proto.input_size(), 2, "GatherElements: two inputs are required");
|
||||
|
||||
size_t num_const = 0;
|
||||
for (size_t i = 0; i < node_proto.input_size(); ++i){
|
||||
if (constBlobs.find(node_proto.input(i)) != constBlobs.end())
|
||||
++num_const;
|
||||
}
|
||||
|
||||
if (num_const == node_proto.input_size())
|
||||
{
|
||||
std::vector<Mat> inputs, output;
|
||||
for (size_t i = 0; i < node_proto.input_size(); i++) {
|
||||
Mat blob = getBlob(node_proto, i);
|
||||
if (i == 1) { // indices, from int32/int64 to float32 for compatibility
|
||||
blob.convertTo(blob, CV_32F);
|
||||
}
|
||||
inputs.push_back(blob);
|
||||
}
|
||||
runLayer(layerParams, inputs, output);
|
||||
CV_Assert(output.size() == 1);
|
||||
addConstant(node_proto.output(0), output[0]);
|
||||
return;
|
||||
} else if (num_const > 0) {
|
||||
for (size_t i = 0; i < node_proto.input_size(); i++) {
|
||||
if (constBlobs.find(node_proto.input(i)) != constBlobs.end()) {
|
||||
Mat blob = getBlob(node_proto, i);
|
||||
if (i == 1) { // indices, from int32/int64 to float32 for compatibility
|
||||
blob.convertTo(blob, CV_32F);
|
||||
}
|
||||
|
||||
LayerParams constParams;
|
||||
constParams.name = node_proto.input(i);
|
||||
constParams.type = "Const";
|
||||
constParams.blobs.push_back(blob);
|
||||
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_output(constParams.name);
|
||||
addLayer(constParams, proto);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseConcat(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
bool hasVariableInps = false;
|
||||
@ -3901,6 +3949,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
|
||||
dispatch["Cast"] = &ONNXImporter::parseCast;
|
||||
dispatch["ConstantFill"] = dispatch["ConstantOfShape"] = &ONNXImporter::parseConstantFill;
|
||||
dispatch["Gather"] = &ONNXImporter::parseGather;
|
||||
dispatch["GatherElements"] = &ONNXImporter::parseGatherElements;
|
||||
dispatch["Concat"] = &ONNXImporter::parseConcat;
|
||||
dispatch["Resize"] = &ONNXImporter::parseResize;
|
||||
dispatch["Upsample"] = &ONNXImporter::parseUpsample;
|
||||
|
@ -55,6 +55,9 @@
|
||||
"test_flatten_negative_axis1",
|
||||
"test_flatten_negative_axis2",
|
||||
"test_flatten_negative_axis4",
|
||||
"test_gather_elements_0",
|
||||
"test_gather_elements_1",
|
||||
"test_gather_elements_negative_indices",
|
||||
"test_logsoftmax_default_axis",
|
||||
"test_maxpool_2d_dilations",
|
||||
"test_maxpool_2d_same_lower",
|
||||
|
@ -115,9 +115,6 @@
|
||||
"test_gather_0",
|
||||
"test_gather_1",
|
||||
"test_gather_2d_indices",
|
||||
"test_gather_elements_0",
|
||||
"test_gather_elements_1",
|
||||
"test_gather_elements_negative_indices",
|
||||
"test_gather_negative_indices",
|
||||
"test_gathernd_example_float32",
|
||||
"test_gathernd_example_int32",
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "test_precomp.hpp"
|
||||
#include "npy_blob.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
#include <numeric>
|
||||
namespace opencv_test { namespace {
|
||||
|
||||
template<typename TString>
|
||||
@ -2134,6 +2135,34 @@ TEST_P(Test_ONNX_nets, Alexnet)
|
||||
expectNoFallbacksFromIE(net);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_nets, RAFT)
|
||||
{
|
||||
applyTestTag(CV_TEST_TAG_LONG, CV_TEST_TAG_DEBUG_VERYLONG, CV_TEST_TAG_MEMORY_2GB);
|
||||
|
||||
std::string weight_path = _tf("models/optical_flow_estimation_raft_2023aug.onnx", false);
|
||||
std::string img0_path = findDataFile(std::string("gpu/opticalflow/frame0.png"));
|
||||
std::string img1_path = findDataFile(std::string("gpu/opticalflow/frame1.png"));
|
||||
|
||||
Size target_size{480, 360};
|
||||
auto img0 = imread(img0_path);
|
||||
auto img1 = imread(img1_path);
|
||||
auto blob0 = blobFromImage(img0, 1.0, target_size, 0, true);
|
||||
auto blob1 = blobFromImage(img1, 1.0, target_size, 0, true);
|
||||
|
||||
auto net = readNet(weight_path);
|
||||
net.setInput(blob0, "0");
|
||||
net.setInput(blob1, "1");
|
||||
std::vector<std::string> outnames{"12007", "12006"};
|
||||
std::vector<Mat> outs;
|
||||
net.forward(outs, outnames);
|
||||
|
||||
// output 12006 is not checked to save space in opencv_extra since its ref is > 1MB,
|
||||
// and output 12006 is calculated from 12007 so checking 12007 is sufficient.
|
||||
std::string ref_12700_path = _tf("data/output_optical_flow_estimation_raft_2023aug.npy");
|
||||
auto ref0 = blobFromNPY(ref_12700_path);
|
||||
normAssert(ref0, outs[0], "", 1e-5, 1.8e-4);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_nets, Squeezenet)
|
||||
{
|
||||
testONNXModels("squeezenet", pb);
|
||||
|
Loading…
Reference in New Issue
Block a user