diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 56683f4c14..01d84d9711 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -494,14 +494,17 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_) MatShape inpShape = outShapes[node_proto.input(0)]; DictValue axes = layerParams.get("axes"); bool keepdims = layerParams.get("keepdims"); - MatShape targetShape = inpShape; + MatShape targetShape; + std::vector shouldDelete(inpShape.size(), false); for (int i = 0; i < axes.size(); i++) { int axis = clamp(axes.get(i), inpShape.size()); - if (keepdims) { - targetShape[axis] = 1; - } else { - targetShape.erase(targetShape.begin() + axis); - } + shouldDelete[axis] = true; + } + for (int axis = 0; axis < inpShape.size(); ++axis){ + if (!shouldDelete[axis]) + targetShape.push_back(inpShape[axis]); + else if (keepdims) + targetShape.push_back(1); } if (inpShape.size() == 3 && axes.size() <= 2)