Merge pull request #17890 from l-bat:onnx_gather

This commit is contained in:
Maksim Shabunin 2020-07-21 09:38:35 +00:00
commit 5444a6b11c
2 changed files with 56 additions and 13 deletions

View File

@ -1342,33 +1342,65 @@ void ONNXImporter::populateNet(Net dstNet)
else if (layer_type == "Gather")
{
CV_Assert(node_proto.input_size() == 2);
Mat input = getBlob(node_proto, constBlobs, 0);
Mat indexMat = getBlob(node_proto, constBlobs, 1);
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
int index = indexMat.at<int>(0);
int axis = layerParams.get<int>("axis", 0);
Mat out;
if (layerParams.has("axis"))
if ((constBlobs.find(node_proto.input(0)) != constBlobs.end()))
{
int axis = layerParams.get<int>("axis");
Mat input = getBlob(node_proto, constBlobs, 0);
Mat out;
std::vector<cv::Range> ranges(input.dims, Range::all());
ranges[axis] = Range(index, index + 1);
out = input(ranges);
}
else
MatShape outShape = shape(out);
if (outShape.size() > 1)
{
CV_Assert(index < input.total());
const int dims = input.dims;
input = input.reshape(1, 1);
input.dims = 2;
out = input.reshape(1, 1).colRange(index, index + 1);
out.dims = dims;
outShape.erase(outShape.begin() + axis);
out.reshape(0, outShape);
}
addConstant(layerParams.name, out, constBlobs, outShapes);
continue;
}
else
{
shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
MatShape inpShape = shapeIt->second;
LayerParams sliceLp;
sliceLp.type = "Slice";
sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name;
std::vector<int> begin(inpShape.size(), 0);
std::vector<int> end(inpShape.size(), -1);
begin[axis] = index;
end[axis] = index + 1;
cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size());
cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size());
sliceLp.set("begin", paramBegin);
sliceLp.set("end", paramEnd);
if (inpShape.size() > 1)
{
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(sliceLp.name);
addLayer(dstNet, sliceLp, proto, layer_id, outShapes);
inpShape.erase(inpShape.begin() + axis);
layerParams.type = "Reshape";
layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
node_proto.set_input(0, sliceLp.name);
}
else
{
layerParams = sliceLp;
}
}
}
else if (layer_type == "Concat")
{
bool hasVariableInps = false;

View File

@ -111,6 +111,17 @@ TEST_P(Test_ONNX_layers, Convolution)
testONNXModels("convolution");
}
TEST_P(Test_ONNX_layers, Gather)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
testONNXModels("gather");
// GPU plugin unsupported slice for constant
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
testONNXModels("gather_scalar", npy, 0, 0, false, false);
}
TEST_P(Test_ONNX_layers, Convolution3D)
{
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000)