Merge pull request #18834 from l-bat:update_reducemax

This commit is contained in:
Alexander Alekhin 2020-11-17 21:14:10 +00:00
commit f7e8dc770a
2 changed files with 32 additions and 5 deletions

View File

@ -551,11 +551,36 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
CV_Assert(axes.size() <= inpShape.size() - 2); CV_Assert(axes.size() <= inpShape.size() - 2);
std::vector<int> kernel_size(inpShape.size() - 2, 1); std::vector<int> kernel_size(inpShape.size() - 2, 1);
for (int i = 0; i < axes.size(); i++) { if (axes.size() == 1 && (clamp(axes.get<int>(0), inpShape.size()) <= 1))
int axis = clamp(axes.get<int>(i), inpShape.size()); {
CV_Assert_N(axis >= 2 + i, axis < inpShape.size()); int axis = clamp(axes.get<int>(0), inpShape.size());
kernel_size[axis - 2] = inpShape[axis]; MatShape newShape = inpShape;
newShape[axis + 1] = total(newShape, axis + 1);
newShape.resize(axis + 2);
newShape.insert(newShape.begin(), 2 - axis, 1);
LayerParams reshapeLp;
reshapeLp.type = "Reshape";
reshapeLp.name = layerParams.name + "/reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], newShape.size()));
node_proto.set_output(0, reshapeLp.name);
addLayer(reshapeLp, node_proto);
kernel_size.resize(2);
kernel_size[0] = inpShape[axis];
node_proto.set_input(0, node_proto.output(0));
} }
else
{
for (int i = 0; i < axes.size(); i++) {
int axis = clamp(axes.get<int>(i), inpShape.size());
CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
kernel_size[axis - 2] = inpShape[axis];
}
}
LayerParams poolLp = layerParams; LayerParams poolLp = layerParams;
poolLp.name = layerParams.name + "/avg"; poolLp.name = layerParams.name + "/avg";
CV_Assert(layer_id.find(poolLp.name) == layer_id.end()); CV_Assert(layer_id.find(poolLp.name) == layer_id.end());

View File

@ -267,9 +267,11 @@ TEST_P(Test_ONNX_layers, ReduceSum)
testONNXModels("reduce_sum"); testONNXModels("reduce_sum");
} }
TEST_P(Test_ONNX_layers, ReduceMaxGlobal) TEST_P(Test_ONNX_layers, ReduceMax)
{ {
testONNXModels("reduce_max"); testONNXModels("reduce_max");
testONNXModels("reduce_max_axis_0");
testONNXModels("reduce_max_axis_1");
} }
TEST_P(Test_ONNX_layers, Scale) TEST_P(Test_ONNX_layers, Scale)