mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #18845 from joegeisbauer:fix_reduce_mean_index_error
This commit is contained in:
commit
0401d5920c
@ -500,14 +500,17 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
|
|||||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||||
DictValue axes = layerParams.get("axes");
|
DictValue axes = layerParams.get("axes");
|
||||||
bool keepdims = layerParams.get<int>("keepdims");
|
bool keepdims = layerParams.get<int>("keepdims");
|
||||||
MatShape targetShape = inpShape;
|
MatShape targetShape;
|
||||||
|
std::vector<bool> shouldDelete(inpShape.size(), false);
|
||||||
for (int i = 0; i < axes.size(); i++) {
|
for (int i = 0; i < axes.size(); i++) {
|
||||||
int axis = clamp(axes.get<int>(i), inpShape.size());
|
int axis = clamp(axes.get<int>(i), inpShape.size());
|
||||||
if (keepdims) {
|
shouldDelete[axis] = true;
|
||||||
targetShape[axis] = 1;
|
}
|
||||||
} else {
|
for (int axis = 0; axis < inpShape.size(); ++axis){
|
||||||
targetShape.erase(targetShape.begin() + axis);
|
if (!shouldDelete[axis])
|
||||||
}
|
targetShape.push_back(inpShape[axis]);
|
||||||
|
else if (keepdims)
|
||||||
|
targetShape.push_back(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (inpShape.size() == 3 && axes.size() <= 2)
|
if (inpShape.size() == 3 && axes.size() <= 2)
|
||||||
|
Loading…
Reference in New Issue
Block a user