diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 41ff6c9b1e..6583d6cf62 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -475,7 +475,8 @@ void ONNXImporter::populateNet() for (int j = 0; j < inpShape.size(); ++j) { inpShape[j] = tensorShape.dim(j).dim_value(); - if (!tensorShape.dim(j).dim_param().empty()) + // NHW, NCHW(NHWC), NCDHW(NDHWC); do not set this flag if only N is dynamic + if (!tensorShape.dim(j).dim_param().empty() && !(j == 0 && inpShape.size() >= 3)) hasDynamicShapes = true; } if (!inpShape.empty() && !hasDynamicShapes) @@ -1407,6 +1408,16 @@ void ONNXImporter::parseMul(LayerParams& layerParams, const opencv_onnx::NodePro //Replace input to Power node_proto.set_input(1, powerParams.name); } + + const MatShape& broadShape = outShapes[node_proto.input(1)]; + const size_t outShapeSize = outShapes[node_proto.input(0)].size(); + const size_t diff = outShapeSize - broadShape.size(); + + size_t axis; + for (axis = diff; axis < broadShape.size() && broadShape[axis - diff] == 1; ++axis) {} + + CV_Assert(axis != outShapeSize); + layerParams.set("axis", static_cast(axis)); layerParams.type = "Scale"; } addLayer(layerParams, node_proto); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 07a0290b9b..bd96729be1 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -284,6 +284,7 @@ TEST_P(Test_ONNX_layers, Scale) if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); testONNXModels("scale"); + testONNXModels("scale_broadcast", npy, 0, 0, false, true, 3); } TEST_P(Test_ONNX_layers, ReduceMean3D) @@ -831,6 +832,7 @@ TEST_P(Test_ONNX_layers, DynamicAxes) testONNXModels("resize_opset11_torch1.6_dynamic_axes"); testONNXModels("average_pooling_dynamic_axes"); testONNXModels("maxpooling_sigmoid_dynamic_axes"); + testONNXModels("dynamic_batch"); } TEST_P(Test_ONNX_layers, MaxPool1d)