mirror of
https://github.com/opencv/opencv.git
synced 2025-06-21 02:20:50 +08:00
Merge pull request #18299 from l-bat:onnx_reduce_max
This commit is contained in:
commit
83807811cd
@ -392,24 +392,21 @@ void ONNXImporter::populateNet(Net dstNet)
|
|||||||
layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
|
layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
|
||||||
}
|
}
|
||||||
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
|
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
|
||||||
layer_type == "ReduceMean" || layer_type == "ReduceSum")
|
layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
|
||||||
{
|
{
|
||||||
CV_Assert(node_proto.input_size() == 1);
|
CV_Assert(node_proto.input_size() == 1);
|
||||||
layerParams.type = "Pooling";
|
layerParams.type = "Pooling";
|
||||||
String pool;
|
String pool;
|
||||||
if (layer_type == "GlobalMaxPool")
|
if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
|
||||||
pool = "MAX";
|
pool = "MAX";
|
||||||
else if (layer_type == "ReduceSum")
|
else if (layer_type == "ReduceSum")
|
||||||
pool = "SUM";
|
pool = "SUM";
|
||||||
else
|
else
|
||||||
pool = "AVE";
|
pool = "AVE";
|
||||||
layerParams.set("pool", pool);
|
layerParams.set("pool", pool);
|
||||||
layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
|
layerParams.set("global_pooling", !layerParams.has("axes"));
|
||||||
if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
|
if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
|
||||||
{
|
{
|
||||||
if (!layerParams.has("axes"))
|
|
||||||
CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
|
|
||||||
|
|
||||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||||
DictValue axes = layerParams.get("axes");
|
DictValue axes = layerParams.get("axes");
|
||||||
bool keepdims = layerParams.get<int>("keepdims");
|
bool keepdims = layerParams.get<int>("keepdims");
|
||||||
@ -487,6 +484,36 @@ void ONNXImporter::populateNet(Net dstNet)
|
|||||||
layerParams.type = "Reshape";
|
layerParams.type = "Reshape";
|
||||||
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
|
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
|
||||||
|
|
||||||
|
node_proto.set_input(0, node_proto.output(0));
|
||||||
|
node_proto.set_output(0, layerParams.name);
|
||||||
|
}
|
||||||
|
else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
|
||||||
|
{
|
||||||
|
CV_CheckEQ(layerParams.get<int>("keepdims"), 0, (layer_type + " layer only supports keepdims = false").c_str());
|
||||||
|
LayerParams reshapeLp;
|
||||||
|
reshapeLp.name = layerParams.name + "/reshape";
|
||||||
|
reshapeLp.type = "Reshape";
|
||||||
|
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
|
||||||
|
int newShape[] = {1, 1, 1, -1};
|
||||||
|
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
|
||||||
|
|
||||||
|
opencv_onnx::NodeProto proto;
|
||||||
|
proto.add_input(node_proto.input(0));
|
||||||
|
proto.add_output(reshapeLp.name);
|
||||||
|
addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
|
||||||
|
|
||||||
|
LayerParams poolLp = layerParams;
|
||||||
|
poolLp.name = layerParams.name + "/pool";
|
||||||
|
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
|
||||||
|
|
||||||
|
node_proto.set_input(0, reshapeLp.name);
|
||||||
|
node_proto.set_output(0, poolLp.name);
|
||||||
|
addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
|
||||||
|
|
||||||
|
layerParams.type = "Reshape";
|
||||||
|
int targetShape[] = {1};
|
||||||
|
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1));
|
||||||
|
|
||||||
node_proto.set_input(0, node_proto.output(0));
|
node_proto.set_input(0, node_proto.output(0));
|
||||||
node_proto.set_output(0, layerParams.name);
|
node_proto.set_output(0, layerParams.name);
|
||||||
}
|
}
|
||||||
@ -1427,8 +1454,10 @@ void ONNXImporter::populateNet(Net dstNet)
|
|||||||
case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break;
|
case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break;
|
||||||
default: type = blob.type();
|
default: type = blob.type();
|
||||||
}
|
}
|
||||||
blob.convertTo(blob, type);
|
Mat dst;
|
||||||
addConstant(layerParams.name, blob, constBlobs, outShapes);
|
blob.convertTo(dst, type);
|
||||||
|
dst.dims = blob.dims;
|
||||||
|
addConstant(layerParams.name, dst, constBlobs, outShapes);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
@ -1477,6 +1506,8 @@ void ONNXImporter::populateNet(Net dstNet)
|
|||||||
{
|
{
|
||||||
outShape.erase(outShape.begin() + axis);
|
outShape.erase(outShape.begin() + axis);
|
||||||
out.reshape(0, outShape);
|
out.reshape(0, outShape);
|
||||||
|
} else {
|
||||||
|
out.dims = 1;
|
||||||
}
|
}
|
||||||
addConstant(layerParams.name, out, constBlobs, outShapes);
|
addConstant(layerParams.name, out, constBlobs, outShapes);
|
||||||
continue;
|
continue;
|
||||||
|
@ -262,6 +262,11 @@ TEST_P(Test_ONNX_layers, ReduceSum)
|
|||||||
testONNXModels("reduce_sum");
|
testONNXModels("reduce_sum");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_ONNX_layers, ReduceMaxGlobal)
|
||||||
|
{
|
||||||
|
testONNXModels("reduce_max");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, ReduceMean3D)
|
TEST_P(Test_ONNX_layers, ReduceMean3D)
|
||||||
{
|
{
|
||||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)
|
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)
|
||||||
|
Loading…
Reference in New Issue
Block a user