mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 03:30:34 +08:00
Merge pull request #18845 from joegeisbauer:fix_reduce_mean_index_error
This commit is contained in:
commit
0401d5920c
@ -500,14 +500,17 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
DictValue axes = layerParams.get("axes");
|
||||
bool keepdims = layerParams.get<int>("keepdims");
|
||||
MatShape targetShape = inpShape;
|
||||
MatShape targetShape;
|
||||
std::vector<bool> shouldDelete(inpShape.size(), false);
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
int axis = clamp(axes.get<int>(i), inpShape.size());
|
||||
if (keepdims) {
|
||||
targetShape[axis] = 1;
|
||||
} else {
|
||||
targetShape.erase(targetShape.begin() + axis);
|
||||
}
|
||||
shouldDelete[axis] = true;
|
||||
}
|
||||
for (int axis = 0; axis < inpShape.size(); ++axis){
|
||||
if (!shouldDelete[axis])
|
||||
targetShape.push_back(inpShape[axis]);
|
||||
else if (keepdims)
|
||||
targetShape.push_back(1);
|
||||
}
|
||||
|
||||
if (inpShape.size() == 3 && axes.size() <= 2)
|
||||
|
Loading…
Reference in New Issue
Block a user