mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Fix Reduce Mean error for MobileNets DNN
Fix for index error for Reduce Mean Correct Reduce Mean indexing error
This commit is contained in:
parent
ed3591ed1f
commit
e05c2e0f1d
@ -494,14 +494,17 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
DictValue axes = layerParams.get("axes");
|
||||
bool keepdims = layerParams.get<int>("keepdims");
|
||||
MatShape targetShape = inpShape;
|
||||
MatShape targetShape;
|
||||
std::vector<bool> shouldDelete(inpShape.size(), false);
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
int axis = clamp(axes.get<int>(i), inpShape.size());
|
||||
if (keepdims) {
|
||||
targetShape[axis] = 1;
|
||||
} else {
|
||||
targetShape.erase(targetShape.begin() + axis);
|
||||
}
|
||||
shouldDelete[axis] = true;
|
||||
}
|
||||
for (int axis = 0; axis < inpShape.size(); ++axis){
|
||||
if (!shouldDelete[axis])
|
||||
targetShape.push_back(inpShape[axis]);
|
||||
else if (keepdims)
|
||||
targetShape.push_back(1);
|
||||
}
|
||||
|
||||
if (inpShape.size() == 3 && axes.size() <= 2)
|
||||
|
Loading…
Reference in New Issue
Block a user