mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
Merge pull request #18834 from l-bat:update_reducemax
This commit is contained in:
commit
f7e8dc770a
@ -551,11 +551,36 @@ void ONNXImporter::handleNode(const opencv_onnx::NodeProto& node_proto_)
|
|||||||
|
|
||||||
CV_Assert(axes.size() <= inpShape.size() - 2);
|
CV_Assert(axes.size() <= inpShape.size() - 2);
|
||||||
std::vector<int> kernel_size(inpShape.size() - 2, 1);
|
std::vector<int> kernel_size(inpShape.size() - 2, 1);
|
||||||
for (int i = 0; i < axes.size(); i++) {
|
if (axes.size() == 1 && (clamp(axes.get<int>(0), inpShape.size()) <= 1))
|
||||||
int axis = clamp(axes.get<int>(i), inpShape.size());
|
{
|
||||||
CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
|
int axis = clamp(axes.get<int>(0), inpShape.size());
|
||||||
kernel_size[axis - 2] = inpShape[axis];
|
MatShape newShape = inpShape;
|
||||||
|
newShape[axis + 1] = total(newShape, axis + 1);
|
||||||
|
newShape.resize(axis + 2);
|
||||||
|
newShape.insert(newShape.begin(), 2 - axis, 1);
|
||||||
|
|
||||||
|
LayerParams reshapeLp;
|
||||||
|
reshapeLp.type = "Reshape";
|
||||||
|
reshapeLp.name = layerParams.name + "/reshape";
|
||||||
|
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
|
||||||
|
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], newShape.size()));
|
||||||
|
|
||||||
|
node_proto.set_output(0, reshapeLp.name);
|
||||||
|
addLayer(reshapeLp, node_proto);
|
||||||
|
|
||||||
|
kernel_size.resize(2);
|
||||||
|
kernel_size[0] = inpShape[axis];
|
||||||
|
node_proto.set_input(0, node_proto.output(0));
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (int i = 0; i < axes.size(); i++) {
|
||||||
|
int axis = clamp(axes.get<int>(i), inpShape.size());
|
||||||
|
CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
|
||||||
|
kernel_size[axis - 2] = inpShape[axis];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
LayerParams poolLp = layerParams;
|
LayerParams poolLp = layerParams;
|
||||||
poolLp.name = layerParams.name + "/avg";
|
poolLp.name = layerParams.name + "/avg";
|
||||||
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
|
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
|
||||||
|
@ -267,9 +267,11 @@ TEST_P(Test_ONNX_layers, ReduceSum)
|
|||||||
testONNXModels("reduce_sum");
|
testONNXModels("reduce_sum");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, ReduceMaxGlobal)
|
TEST_P(Test_ONNX_layers, ReduceMax)
|
||||||
{
|
{
|
||||||
testONNXModels("reduce_max");
|
testONNXModels("reduce_max");
|
||||||
|
testONNXModels("reduce_max_axis_0");
|
||||||
|
testONNXModels("reduce_max_axis_1");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, Scale)
|
TEST_P(Test_ONNX_layers, Scale)
|
||||||
|
Loading…
Reference in New Issue
Block a user