From 9c5d7716e273284d4973eefdef079436c7caaa9d Mon Sep 17 00:00:00 2001 From: SamFC10 Date: Fri, 17 Sep 2021 17:40:57 +0530 Subject: [PATCH] fix for unsqueeze opset version 13 --- modules/dnn/src/onnx/onnx_importer.cpp | 12 ++++++++++-- modules/dnn/test/test_onnx_importer.cpp | 1 + 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 2cd36dc94d..5343c05361 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1693,8 +1693,16 @@ void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::Nod void ONNXImporter::parseUnsqueeze(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) { - CV_Assert(node_proto.input_size() == 1); - DictValue axes = layerParams.get("axes"); + CV_Assert(node_proto.input_size() == 1 || node_proto.input_size() == 2); + DictValue axes; + if (node_proto.input_size() == 2) + { + Mat blob = getBlob(node_proto, 1); + axes = DictValue::arrayInt(blob.ptr(), blob.total()); + } + else + axes = layerParams.get("axes"); + if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) { // Constant input. diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index f55510ec7b..b4ecd4601c 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -605,6 +605,7 @@ TEST_P(Test_ONNX_layers, DynamicReshape) TEST_P(Test_ONNX_layers, Reshape) { testONNXModels("unsqueeze"); + testONNXModels("unsqueeze_opset_13"); } TEST_P(Test_ONNX_layers, Squeeze)