mirror of
https://github.com/opencv/opencv.git
synced 2024-12-04 00:39:11 +08:00
Merge pull request #17890 from l-bat:onnx_gather
This commit is contained in:
commit
5444a6b11c
@ -1342,32 +1342,64 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
else if (layer_type == "Gather")
|
||||
{
|
||||
CV_Assert(node_proto.input_size() == 2);
|
||||
Mat input = getBlob(node_proto, constBlobs, 0);
|
||||
Mat indexMat = getBlob(node_proto, constBlobs, 1);
|
||||
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
|
||||
int index = indexMat.at<int>(0);
|
||||
int axis = layerParams.get<int>("axis", 0);
|
||||
|
||||
Mat out;
|
||||
if (layerParams.has("axis"))
|
||||
if ((constBlobs.find(node_proto.input(0)) != constBlobs.end()))
|
||||
{
|
||||
int axis = layerParams.get<int>("axis");
|
||||
|
||||
Mat input = getBlob(node_proto, constBlobs, 0);
|
||||
Mat out;
|
||||
std::vector<cv::Range> ranges(input.dims, Range::all());
|
||||
ranges[axis] = Range(index, index + 1);
|
||||
|
||||
out = input(ranges);
|
||||
MatShape outShape = shape(out);
|
||||
if (outShape.size() > 1)
|
||||
{
|
||||
outShape.erase(outShape.begin() + axis);
|
||||
out.reshape(0, outShape);
|
||||
}
|
||||
addConstant(layerParams.name, out, constBlobs, outShapes);
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_Assert(index < input.total());
|
||||
const int dims = input.dims;
|
||||
input = input.reshape(1, 1);
|
||||
input.dims = 2;
|
||||
out = input.reshape(1, 1).colRange(index, index + 1);
|
||||
out.dims = dims;
|
||||
shapeIt = outShapes.find(node_proto.input(0));
|
||||
CV_Assert(shapeIt != outShapes.end());
|
||||
MatShape inpShape = shapeIt->second;
|
||||
|
||||
LayerParams sliceLp;
|
||||
sliceLp.type = "Slice";
|
||||
sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name;
|
||||
std::vector<int> begin(inpShape.size(), 0);
|
||||
std::vector<int> end(inpShape.size(), -1);
|
||||
begin[axis] = index;
|
||||
end[axis] = index + 1;
|
||||
|
||||
cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size());
|
||||
cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size());
|
||||
sliceLp.set("begin", paramBegin);
|
||||
sliceLp.set("end", paramEnd);
|
||||
|
||||
if (inpShape.size() > 1)
|
||||
{
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_input(node_proto.input(0));
|
||||
proto.add_output(sliceLp.name);
|
||||
addLayer(dstNet, sliceLp, proto, layer_id, outShapes);
|
||||
|
||||
inpShape.erase(inpShape.begin() + axis);
|
||||
layerParams.type = "Reshape";
|
||||
layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
|
||||
node_proto.set_input(0, sliceLp.name);
|
||||
}
|
||||
else
|
||||
{
|
||||
layerParams = sliceLp;
|
||||
}
|
||||
}
|
||||
addConstant(layerParams.name, out, constBlobs, outShapes);
|
||||
continue;
|
||||
}
|
||||
else if (layer_type == "Concat")
|
||||
{
|
||||
|
@ -111,6 +111,17 @@ TEST_P(Test_ONNX_layers, Convolution)
|
||||
testONNXModels("convolution");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Gather)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
testONNXModels("gather");
|
||||
// GPU plugin unsupported slice for constant
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
|
||||
testONNXModels("gather_scalar", npy, 0, 0, false, false);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Convolution3D)
|
||||
{
|
||||
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000)
|
||||
|
Loading…
Reference in New Issue
Block a user