mirror of
https://github.com/opencv/opencv.git
synced 2024-12-04 00:39:11 +08:00
Merge pull request #26106 from Abdurrahheem:ash/add-gatherND
Support for GatherND layer #26106 This PR adds support for GatherND layer. The layer was in comformance deny list initially. ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [ ] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
a18d793dbd
commit
8263c804de
@ -280,6 +280,17 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
static Ptr<HardmaxLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
/** @brief GatherND layer
|
||||
*
|
||||
* GatherND takes two inputs data and indices of rank r >= 1 and q >= 1 respectively,
|
||||
* and an optional attribute batch_dims. It gathers slices from data into an output tensor.
|
||||
*/
|
||||
class CV_EXPORTS GatherNDLayer : public Layer
|
||||
{
|
||||
public:
|
||||
static Ptr<GatherNDLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
class CV_EXPORTS BaseConvolutionLayer : public Layer
|
||||
{
|
||||
public:
|
||||
|
@ -197,6 +197,7 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(CumSum, CumSumLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Einsum, EinsumLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Hardmax, HardmaxLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(GatherND, GatherNDLayer);
|
||||
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Scatter, ScatterLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(ScatterND, ScatterNDLayer);
|
||||
|
185
modules/dnn/src/layers/gatherND.cpp
Normal file
185
modules/dnn/src/layers/gatherND.cpp
Normal file
@ -0,0 +1,185 @@
|
||||
#include "../precomp.hpp"
|
||||
#include <opencv2/dnn/shape_utils.hpp>
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
class GatherNDLayerImpl CV_FINAL : public GatherNDLayer
|
||||
{
|
||||
public:
|
||||
GatherNDLayerImpl(const LayerParams& params)
|
||||
{
|
||||
setParamsFrom(params);
|
||||
batch_dims = params.get<int>("batch_dims", 0);
|
||||
}
|
||||
|
||||
void getTypes(const std::vector<MatType>& inputs,
|
||||
const int requiredOutputs,
|
||||
const int requiredInternals,
|
||||
std::vector<MatType>& outputs,
|
||||
std::vector<MatType>& internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(inputs.size() == 2);
|
||||
|
||||
MatType dataType = inputs[0];
|
||||
MatType indicesType = inputs[1];
|
||||
|
||||
// Check that indices are always integer type
|
||||
CV_CheckType(indicesType, indicesType == CV_32S || indicesType == CV_64S,
|
||||
"GatherND: indices must be CV_32S or CV_64S");
|
||||
|
||||
if (preferableTarget == DNN_TARGET_OPENCL_FP16)
|
||||
{
|
||||
CV_CheckType(dataType, dataType == CV_16F || dataType == CV_8S || dataType == CV_8U ||
|
||||
dataType == CV_32S || dataType == CV_64S,
|
||||
"GatherND: unsupported data type for OpenCL FP16 target");
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_CheckType(dataType, dataType == CV_32F || dataType == CV_8S || dataType == CV_8U ||
|
||||
dataType == CV_32S || dataType == CV_64S,
|
||||
"GatherND: unsupported data type");
|
||||
}
|
||||
|
||||
outputs.resize(1, dataType);
|
||||
internals.clear();
|
||||
}
|
||||
|
||||
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
std::vector<MatShape> &internals) const CV_OVERRIDE
|
||||
{
|
||||
CV_CheckEQ(inputs.size(), 2ull, "GatherND: requires two inputs");
|
||||
const MatShape& data = inputs[0];
|
||||
const MatShape& indices = inputs[1];
|
||||
|
||||
int r = data.size();
|
||||
int q = indices.size();
|
||||
int last_indices_dim = indices[q - 1];
|
||||
|
||||
CV_CheckGE(r, 1, "GatherND: data rank must be >= 1");
|
||||
CV_CheckGE(q, 1, "GatherND: indices rank must be >= 1");
|
||||
CV_CheckLE(batch_dims, std::min(q, r), "GatherND: batch_dims must be <= min(q, r)");
|
||||
CV_CheckGE(last_indices_dim, 1, "GatherND: last dimension of indices must be >= 1");
|
||||
CV_CheckLE(last_indices_dim, r - batch_dims, "GatherND: last dimension of indices must be <= r - batch_dims");
|
||||
|
||||
MatShape output_shape;
|
||||
output_shape.reserve(q - 1 + r - batch_dims - last_indices_dim);
|
||||
for (int i = 0; i < q - 1; ++i)
|
||||
output_shape.push_back(indices[i]);
|
||||
for (int i = batch_dims + last_indices_dim; i < r; ++i)
|
||||
output_shape.push_back(data[i]);
|
||||
|
||||
outputs.assign(1, output_shape);
|
||||
return false;
|
||||
}
|
||||
|
||||
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
|
||||
{
|
||||
CV_TRACE_FUNCTION();
|
||||
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
|
||||
|
||||
std::vector<Mat> inputs, outputs;
|
||||
inputs_arr.getMatVector(inputs);
|
||||
outputs_arr.getMatVector(outputs);
|
||||
|
||||
const Mat& data = inputs[0];
|
||||
const Mat& indices = inputs[1];
|
||||
Mat& out = outputs[0];
|
||||
|
||||
int dtype = data.depth();
|
||||
int itype = indices.depth();
|
||||
|
||||
switch (itype) {
|
||||
case CV_32S:
|
||||
{
|
||||
switch (dtype) {
|
||||
case CV_8U: forward_impl<int32_t, uchar>(data, indices, out); break;
|
||||
case CV_8S: forward_impl<int32_t, schar>(data, indices, out); break;
|
||||
case CV_32S: forward_impl<int32_t, int32_t>(data, indices, out); break;
|
||||
case CV_16F: forward_impl<int32_t, int16_t>(data, indices, out); break;
|
||||
case CV_32F: forward_impl<int32_t, float>(data, indices, out); break;
|
||||
case CV_64F: forward_impl<int32_t, double>(data, indices, out); break;
|
||||
default: CV_Error(Error::StsNotImplemented, "Unsupported data type");
|
||||
}
|
||||
} break;
|
||||
case CV_64S:
|
||||
{
|
||||
switch (dtype) {
|
||||
case CV_8U: forward_impl<int64_t, uchar>(data, indices, out); break;
|
||||
case CV_8S: forward_impl<int64_t, schar>(data, indices, out); break;
|
||||
case CV_32S: forward_impl<int64_t, int32_t>(data, indices, out); break;
|
||||
case CV_16F: forward_impl<int64_t, int16_t>(data, indices, out); break;
|
||||
case CV_32F: forward_impl<int64_t, float>(data, indices, out); break;
|
||||
case CV_64F: forward_impl<int64_t, double>(data, indices, out); break;
|
||||
default: CV_Error(Error::StsNotImplemented, "Unsupported data type");
|
||||
}
|
||||
} break;
|
||||
default: CV_Error(Error::StsNotImplemented, "Unsupported indices type");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename iT, typename dT>
|
||||
void forward_impl(const Mat& data, const Mat& indices, Mat& out)
|
||||
{
|
||||
CV_Assert(out.isContinuous());
|
||||
CV_Assert(indices.isContinuous());
|
||||
CV_Assert(data.isContinuous());
|
||||
|
||||
|
||||
const iT* indices_ptr = indices.ptr<iT>();
|
||||
const dT* data_ptr = data.ptr<dT>();
|
||||
dT* out_ptr = out.ptr<dT>();
|
||||
|
||||
size_t r = data.dims;
|
||||
size_t q = indices.dims;
|
||||
size_t last_indices_dim = indices.size[q - 1];
|
||||
|
||||
std::vector<int> data_strides(r);
|
||||
data_strides[r - 1] = 1;
|
||||
for (int i = r - 2; i >= 0; --i)
|
||||
data_strides[i] = data_strides[i + 1] * data.size[i + 1];
|
||||
|
||||
std::vector<int> indices_strides(q);
|
||||
indices_strides[q - 1] = 1;
|
||||
for (int i = q - 2; i >= 0; --i)
|
||||
indices_strides[i] = indices_strides[i + 1] * indices.size[i + 1];
|
||||
|
||||
const int outer_size = indices.total() / last_indices_dim;
|
||||
const int inner_size = out.total() / outer_size;
|
||||
const int nstripes = outer_size * inner_size / 1024;
|
||||
|
||||
parallel_for_(Range(0, outer_size), [&](const Range& range) {
|
||||
for (size_t i = range.start; i < range.end; ++i)
|
||||
{
|
||||
const iT* sliced_indices = indices_ptr + i * last_indices_dim;
|
||||
|
||||
size_t offset = 0;
|
||||
for (size_t j = 0; j < last_indices_dim; ++j)
|
||||
{
|
||||
offset += sliced_indices[j] * data_strides[batch_dims + j];
|
||||
}
|
||||
|
||||
if (batch_dims > 0)
|
||||
offset += data_strides[batch_dims - 1] * i;
|
||||
|
||||
// copy data from data to out
|
||||
for (size_t j = 0; j < inner_size; ++j)
|
||||
{
|
||||
out_ptr[i * inner_size + j] = data_ptr[offset + j];
|
||||
}
|
||||
}
|
||||
}, nstripes);
|
||||
}
|
||||
|
||||
private:
|
||||
int batch_dims;
|
||||
};
|
||||
|
||||
Ptr<GatherNDLayer> GatherNDLayer::create(const LayerParams& params)
|
||||
{
|
||||
return Ptr<GatherNDLayer>(new GatherNDLayerImpl(params));
|
||||
}
|
||||
|
||||
}} // namespace cv::dnn
|
@ -199,6 +199,7 @@ private:
|
||||
void parseSimpleLayers (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseEinsum (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseHardmax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseGatherND (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
|
||||
// Domain: com.microsoft
|
||||
// URL: https://github.com/microsoft/onnxruntime/blob/master/docs/ContribOperators.md
|
||||
@ -3213,6 +3214,15 @@ void ONNXImporter::parseHardmax(LayerParams& layerParams, const opencv_onnx::Nod
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseGatherND(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
CV_Assert(node_proto.input_size() == 2);
|
||||
layerParams.type = "GatherND";
|
||||
int batch_dims = layerParams.get<int>("batch_dims", 0);
|
||||
layerParams.set("batch_dims", batch_dims);
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseEinsum(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
std::vector<MatShape> einsumInpShapes;
|
||||
@ -4006,6 +4016,7 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
|
||||
dispatch["Range"] = &ONNXImporter::parseRange;
|
||||
dispatch["Einsum"] = &ONNXImporter::parseEinsum;
|
||||
dispatch["Hardmax"] = &ONNXImporter::parseHardmax;
|
||||
dispatch["GatherND"] = &ONNXImporter::parseGatherND;
|
||||
|
||||
std::vector<std::string> simpleLayers{"Acos", "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil", "Celu", "Cos",
|
||||
"Cosh", "Dropout", "Erf", "Exp", "Floor", "HardSigmoid", "HardSwish",
|
||||
|
@ -100,9 +100,6 @@
|
||||
"test_eyelike_populate_off_main_diagonal", // Issues::Layer::Can't create layer::Can't create layer "onnx_node_output_0!y" of type "EyeLike" in function 'getLayerInstance'
|
||||
"test_eyelike_with_dtype", // ---- same as above ---
|
||||
"test_eyelike_without_dtype", // ---- same as above ---
|
||||
"test_gathernd_example_float32", // Issues::Layer::Can't create layer
|
||||
"test_gathernd_example_int32", // ---- same as above ---
|
||||
"test_gathernd_example_int32_batch_dim1", // ---- same as above ---
|
||||
"test_gelu_default_1_expanded", // parser: no corresponding layer for CastLike
|
||||
"test_gelu_default_2_expanded", // parser: no corresponding layer for CastLike
|
||||
"test_gelu_tanh_1_expanded", // parser: no corresponding layer for CastLike
|
||||
|
Loading…
Reference in New Issue
Block a user