Mirror of https://github.com/opencv/opencv.git (synced 2025-06-07 09:25:45 +08:00).
Commit: add Gather layer implementation.
This commit is contained in:
parent
448e3a7e58
commit
65f71ce2eb
@ -301,6 +301,14 @@ CV__DNN_INLINE_NS_BEGIN
|
||||
static Ptr<ArgLayer> create(const LayerParams& params);
|
||||
};
|
||||
|
||||
/** @brief Gather layer
 *
 * Gathers slices of the first input along a given axis, selected by the
 * indices supplied as the second input (ONNX Gather semantics).
 */
class CV_EXPORTS GatherLayer : public Layer
{
public:
    static Ptr<GatherLayer> create(const LayerParams& params);
};
|
||||
|
||||
class CV_EXPORTS PoolingLayer : public Layer
|
||||
{
|
||||
public:
|
||||
|
@ -147,6 +147,7 @@ void initializeLayerFactory()
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Const, ConstLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Arg, ArgLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Reciprocal, ReciprocalLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Gather, GatherLayer);
|
||||
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);
|
||||
|
modules/dnn/src/layers/gather_layer.cpp — new file, 91 lines
@ -0,0 +1,91 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp"
|
||||
#include "layers_common.hpp"
|
||||
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
// CPU implementation of the Gather layer: copies rows of the input selected
// by the indices tensor along a single (normalized) axis.
class GatherLayerImpl CV_FINAL : public GatherLayer
{
public:
    GatherLayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);
        // Axis to gather along; may be negative (normalized against input rank later).
        m_axis = params.get<int>("axis", 0);
        // Logical rank of the indices tensor as declared by the model
        // (-1 means "use the full indices shape" in getMemoryShapes).
        m_real_ndims = params.get<int>("real_ndims", -1);
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        // Plain CPU implementation only.
        return backendId == DNN_BACKEND_OPENCV;
    }

    // Output shape = input shape with dimension `axis` replaced by the indices
    // shape (truncated to m_real_ndims leading dims when that hint is set).
    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                                 const int requiredOutputs,
                                 std::vector<MatShape> &outputs,
                                 std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_CheckEQ(inputs.size(), 2ull, "");
        MatShape inpShape = inputs[0];
        const int axis = normalize_axis(m_axis, inpShape);

        inpShape.erase(inpShape.begin() + axis);
        auto end = m_real_ndims == -1 ? inputs[1].end() : inputs[1].begin() + m_real_ndims;
        inpShape.insert(inpShape.begin() + axis, inputs[1].begin(), end);

        outputs.assign(1, inpShape);
        return false;  // shapes are computed, not taken from the inputs as-is
    }

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        const Mat& inp = inputs[0];
        const Mat& indices = inputs[1];
        Mat& out = outputs[0];

        const int axis = normalize_axis(m_axis, shape(inp));

        // View the input as [outer_dims, size[axis], inner_size] (element counts):
        //   outer_size — elements in one slice covering dims from `axis-1` down,
        //   inner_size — contiguous elements spanned by the dims after `axis`.
        const size_t outer_size = axis == 0 ? inp.total() : inp.step1(axis - 1);
        const size_t outer_dims = inp.total() / outer_size;
        const size_t inner_size = inp.step1(axis);

        const float* idx = indices.ptr<const float>(); // TODO: change type to integer in the future.
        const char* src = inp.ptr<const char>();
        char* dst = out.ptr<char>();

        const size_t es = inp.elemSize1();  // bytes per element; copies are done byte-wise
        for (size_t i = 0; i < outer_dims; ++i)
        {
            const size_t src_offset = i * outer_size;
            for (size_t j = 0 ; j < indices.total(); ++j)
            {
                // Wrap negative indices; assumes idx in [-size[axis], size[axis]-1]
                // as the ONNX spec requires — values outside that range wrap silently.
                const size_t index = (static_cast<int>(idx[j]) + inp.size[axis]) % inp.size[axis];
                const size_t new_offset = src_offset + index * inp.step1(axis);
                std::memcpy(dst, src + new_offset * es, inner_size * es);
                dst += inner_size * es;
            }
        }
    }

private:
    // The axis to gather along
    int m_axis;
    int m_real_ndims;
};
|
||||
|
||||
// Factory: wraps the internal GatherLayerImpl behind the public GatherLayer interface.
Ptr<GatherLayer> GatherLayer::create(const LayerParams& params)
{
    return makePtr<GatherLayerImpl>(params);
}
|
||||
|
||||
}} // namespace cv::dnn
|
@ -2622,83 +2622,57 @@ void ONNXImporter::parseConstantFill(LayerParams& layerParams, const opencv_onnx
|
||||
addConstant(node_proto.output(0), tensor);
|
||||
}
|
||||
|
||||
// Import an ONNX Gather node.
// The pasted diff interleaved the old Slice-based importer with the new
// Gather-layer importer; this is the reconstructed NEW implementation.
// - When the indices input is constant, its original rank is recorded as
//   "real_ndims" so the layer can restore 0-d/1-d output shapes.
// - When BOTH inputs are constant, the gather is folded at import time by
//   running the layer immediately and storing the result as a constant.
// - Otherwise any constant input is materialized as a Const layer and a
//   runtime Gather layer is added to the graph.
void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
    CV_CheckEQ(node_proto.input_size(), 2, "");

    // TODO: get rid of the type conversions and 1-d/0-d special-casing when the time comes
    if (layer_id.find(node_proto.input(1)) == layer_id.end())
    {
        int real_ndims = getBlobExtraInfo(node_proto.input(1)).real_ndims;
        layerParams.set("real_ndims", real_ndims);
        if (layer_id.find(node_proto.input(0)) == layer_id.end())
        {
            // Both inputs are constants: evaluate the gather now (constant folding).
            std::vector<Mat> inputs, output;

            Mat input = getBlob(node_proto, 0);
            int input_real_ndims = input.dims;
            int type = input.type();
            input.convertTo(input, CV_32FC1);  // layer currently computes in float
            inputs.push_back(input);

            Mat indices = getBlob(node_proto, 1);
            indices.convertTo(indices, CV_32FC1);  // layer reads indices as float for now
            inputs.push_back(indices);

            runLayer(layerParams, inputs, output);
            output.back().convertTo(output.back(), type);  // restore the original element type
            // NOTE(review): restores a reduced logical rank for low-dimensional
            // indices (clamped to at least 1) — confirm the formula against the
            // gather_scalar test models.
            output.back().dims = std::max(input_real_ndims - real_ndims, 1);
            addConstant(node_proto.output(0), output.back());
            return;
        }
    }

    // Materialize remaining constant inputs as Const layers feeding the Gather.
    for (int i = 0; i < node_proto.input_size(); ++i)
    {
        if (layer_id.find(node_proto.input(i)) == layer_id.end())
        {
            LayerParams constParams;
            constParams.name = node_proto.input(i);
            constParams.type = "Const";
            Mat blob = getBlob(node_proto, i);
            if (i == 1)
            {
                // Indices input: convert to float to match the layer's expectation.
                blob.convertTo(blob, CV_32FC1);
            }
            constParams.blobs.push_back(blob);

            opencv_onnx::NodeProto proto;
            proto.add_output(constParams.name);
            addLayer(constParams, proto);
        }
    }

    addLayer(layerParams, node_proto);
}
|
||||
|
||||
|
@ -207,11 +207,16 @@ TEST_P(Test_ONNX_layers, Gather)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
testONNXModels("gather");
|
||||
testONNXModels("gather", npy, 0, 0, false, false);
|
||||
testONNXModels("gather_scalar", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
// New test for multi-dimensional index gathering. The stray "gather_scalar"
// line in the pasted diff is the removed half of a diff pair (that model is
// now exercised by the Gather test above), so only gather_multi runs here.
TEST_P(Test_ONNX_layers, GatherMulti)
{
    // GPU plugin unsupported slice for constant
    if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
        applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
    testONNXModels("gather_multi", npy, 0, 0, false, false);
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Convolution3D)
|
||||
@ -1424,7 +1429,7 @@ TEST_P(Test_ONNX_layers, GatherMultiOutput)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE);
|
||||
#endif
|
||||
|
||||
testONNXModels("gather_multi_output");
|
||||
testONNXModels("gather_multi_output", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, DynamicAxes_squeeze_and_conv)
|
||||
@ -1475,7 +1480,7 @@ TEST_P(Test_ONNX_layers, DynamicAxes_gather)
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
testONNXModels("gather_dynamic_axes");
|
||||
testONNXModels("gather_dynamic_axes", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, DynamicAxes_gather_scalar)
|
||||
@ -1504,7 +1509,7 @@ TEST_P(Test_ONNX_layers, DynamicAxes_gather_scalar)
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
testONNXModels("gather_scalar_dynamic_axes");
|
||||
testONNXModels("gather_scalar_dynamic_axes", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, DynamicAxes_slice)
|
||||
|
Loading…
Reference in New Issue
Block a user