mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 21:20:18 +08:00
Support global reduce ops
This commit is contained in:
parent
6b674709b8
commit
b542a1804c
@ -392,24 +392,21 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
|
||||
}
|
||||
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
|
||||
layer_type == "ReduceMean" || layer_type == "ReduceSum")
|
||||
layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
|
||||
{
|
||||
CV_Assert(node_proto.input_size() == 1);
|
||||
layerParams.type = "Pooling";
|
||||
String pool;
|
||||
if (layer_type == "GlobalMaxPool")
|
||||
if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
|
||||
pool = "MAX";
|
||||
else if (layer_type == "ReduceSum")
|
||||
pool = "SUM";
|
||||
else
|
||||
pool = "AVE";
|
||||
layerParams.set("pool", pool);
|
||||
layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
|
||||
if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
|
||||
layerParams.set("global_pooling", !layerParams.has("axes"));
|
||||
if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
|
||||
{
|
||||
if (!layerParams.has("axes"))
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
|
||||
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
DictValue axes = layerParams.get("axes");
|
||||
bool keepdims = layerParams.get<int>("keepdims");
|
||||
@ -487,6 +484,36 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.type = "Reshape";
|
||||
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
|
||||
|
||||
node_proto.set_input(0, node_proto.output(0));
|
||||
node_proto.set_output(0, layerParams.name);
|
||||
}
|
||||
else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
|
||||
{
|
||||
CV_CheckEQ(layerParams.get<int>("keepdims"), 0, (layer_type + " layer only supports keepdims = false").c_str());
|
||||
LayerParams reshapeLp;
|
||||
reshapeLp.name = layerParams.name + "/reshape";
|
||||
reshapeLp.type = "Reshape";
|
||||
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
|
||||
int newShape[] = {1, 1, 1, -1};
|
||||
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
|
||||
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_input(node_proto.input(0));
|
||||
proto.add_output(reshapeLp.name);
|
||||
addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
|
||||
|
||||
LayerParams poolLp = layerParams;
|
||||
poolLp.name = layerParams.name + "/pool";
|
||||
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
|
||||
|
||||
node_proto.set_input(0, reshapeLp.name);
|
||||
node_proto.set_output(0, poolLp.name);
|
||||
addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
|
||||
|
||||
layerParams.type = "Reshape";
|
||||
int targetShape[] = {1};
|
||||
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1));
|
||||
|
||||
node_proto.set_input(0, node_proto.output(0));
|
||||
node_proto.set_output(0, layerParams.name);
|
||||
}
|
||||
@ -1427,8 +1454,10 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break;
|
||||
default: type = blob.type();
|
||||
}
|
||||
blob.convertTo(blob, type);
|
||||
addConstant(layerParams.name, blob, constBlobs, outShapes);
|
||||
Mat dst;
|
||||
blob.convertTo(dst, type);
|
||||
dst.dims = blob.dims;
|
||||
addConstant(layerParams.name, dst, constBlobs, outShapes);
|
||||
continue;
|
||||
}
|
||||
else
|
||||
@ -1477,6 +1506,8 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
{
|
||||
outShape.erase(outShape.begin() + axis);
|
||||
out.reshape(0, outShape);
|
||||
} else {
|
||||
out.dims = 1;
|
||||
}
|
||||
addConstant(layerParams.name, out, constBlobs, outShapes);
|
||||
continue;
|
||||
|
@ -262,6 +262,11 @@ TEST_P(Test_ONNX_layers, ReduceSum)
|
||||
testONNXModels("reduce_sum");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ReduceMaxGlobal)
|
||||
{
|
||||
testONNXModels("reduce_max");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ReduceMean3D)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)
|
||||
|
Loading…
Reference in New Issue
Block a user