diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 682418bffa..220cae813e 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1342,32 +1342,64 @@ void ONNXImporter::populateNet(Net dstNet) else if (layer_type == "Gather") { CV_Assert(node_proto.input_size() == 2); - Mat input = getBlob(node_proto, constBlobs, 0); Mat indexMat = getBlob(node_proto, constBlobs, 1); CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1); int index = indexMat.at(0); + int axis = layerParams.get("axis", 0); - Mat out; - if (layerParams.has("axis")) + if ((constBlobs.find(node_proto.input(0)) != constBlobs.end())) { - int axis = layerParams.get("axis"); - + Mat input = getBlob(node_proto, constBlobs, 0); + Mat out; std::vector ranges(input.dims, Range::all()); ranges[axis] = Range(index, index + 1); out = input(ranges); + MatShape outShape = shape(out); + if (outShape.size() > 1) + { + outShape.erase(outShape.begin() + axis); + out.reshape(0, outShape); + } + addConstant(layerParams.name, out, constBlobs, outShapes); + continue; } else { - CV_Assert(index < input.total()); - const int dims = input.dims; - input = input.reshape(1, 1); - input.dims = 2; - out = input.reshape(1, 1).colRange(index, index + 1); - out.dims = dims; + shapeIt = outShapes.find(node_proto.input(0)); + CV_Assert(shapeIt != outShapes.end()); + MatShape inpShape = shapeIt->second; + + LayerParams sliceLp; + sliceLp.type = "Slice"; + sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name; + std::vector begin(inpShape.size(), 0); + std::vector end(inpShape.size(), -1); + begin[axis] = index; + end[axis] = index + 1; + + cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size()); + cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size()); + sliceLp.set("begin", paramBegin); + sliceLp.set("end", paramEnd); + + if (inpShape.size() > 1) + { + opencv_onnx::NodeProto proto; + proto.add_input(node_proto.input(0)); + proto.add_output(sliceLp.name); + addLayer(dstNet, sliceLp, proto, layer_id, outShapes); + + inpShape.erase(inpShape.begin() + axis); + layerParams.type = "Reshape"; + layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size())); + node_proto.set_input(0, sliceLp.name); + } + else + { + layerParams = sliceLp; + } } - addConstant(layerParams.name, out, constBlobs, outShapes); - continue; } else if (layer_type == "Concat") { diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 4c8e66aae1..e932bc6919 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -111,6 +111,17 @@ TEST_P(Test_ONNX_layers, Convolution) testONNXModels("convolution"); } +TEST_P(Test_ONNX_layers, Gather) +{ + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + testONNXModels("gather"); + // GPU plugin unsupported slice for constant + if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16)) + applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH); + testONNXModels("gather_scalar", npy, 0, 0, false, false); +} + TEST_P(Test_ONNX_layers, Convolution3D) { #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000)