Support global reduce ops

Liubov Batanina 2020-09-09 10:40:02 +03:00
parent 6b674709b8
commit b542a1804c
2 changed files with 45 additions and 9 deletions

modules/dnn/src/onnx/onnx_importer.cpp

@@ -392,24 +392,21 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
         }
         else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
-                layer_type == "ReduceMean" || layer_type == "ReduceSum")
+                layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
         {
             CV_Assert(node_proto.input_size() == 1);
             layerParams.type = "Pooling";
             String pool;
-            if (layer_type == "GlobalMaxPool")
+            if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
                 pool = "MAX";
             else if (layer_type == "ReduceSum")
                 pool = "SUM";
             else
                 pool = "AVE";
             layerParams.set("pool", pool);
-            layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
-            if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
+            layerParams.set("global_pooling", !layerParams.has("axes"));
+            if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
             {
-                if (!layerParams.has("axes"))
-                    CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
                 MatShape inpShape = outShapes[node_proto.input(0)];
                 DictValue axes = layerParams.get("axes");
                 bool keepdims = layerParams.get<int>("keepdims");
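
Note on the axes path above: with an explicit axes attribute, the reduce op is still lowered to a Pooling layer whose kernel spans the reduced axes. A minimal sketch of that kind of mapping, assuming a ReduceMean over axes {2, 3} of an NCHW input (the helper function, kernel sizes, and parameter values are illustrative, not the importer's exact code):

    #include <opencv2/dnn.hpp>

    // Sketch only: roughly the Pooling layer a ReduceMean with axes = {2, 3}
    // and keepdims = 1 corresponds to for an NCHW input of spatial size H x W.
    void sketchReduceAsPooling()
    {
        const int H = 7, W = 7;            // hypothetical input spatial sizes
        cv::dnn::LayerParams lp;
        lp.type = "Pooling";
        lp.set("pool", "AVE");             // ReduceSum -> "SUM", ReduceMax -> "MAX"
        lp.set("global_pooling", false);   // axes are present, so not a global pooling
        lp.set("kernel_h", H);
        lp.set("kernel_w", W);
    }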
@@ -487,6 +484,36 @@ void ONNXImporter::populateNet(Net dstNet)
                 layerParams.type = "Reshape";
                 layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
 
                 node_proto.set_input(0, node_proto.output(0));
                 node_proto.set_output(0, layerParams.name);
             }
+            else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
+            {
+                CV_CheckEQ(layerParams.get<int>("keepdims"), 0, (layer_type + " layer only supports keepdims = false").c_str());
+                LayerParams reshapeLp;
+                reshapeLp.name = layerParams.name + "/reshape";
+                reshapeLp.type = "Reshape";
+                CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
+                int newShape[] = {1, 1, 1, -1};
+                reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
+
+                opencv_onnx::NodeProto proto;
+                proto.add_input(node_proto.input(0));
+                proto.add_output(reshapeLp.name);
+                addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
+
+                LayerParams poolLp = layerParams;
+                poolLp.name = layerParams.name + "/pool";
+                CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
+                node_proto.set_input(0, reshapeLp.name);
+                node_proto.set_output(0, poolLp.name);
+                addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
+
+                layerParams.type = "Reshape";
+                int targetShape[] = {1};
+                layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1));
+
+                node_proto.set_input(0, node_proto.output(0));
+                node_proto.set_output(0, layerParams.name);
+            }
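
The new branch above covers a reduce with no axes attribute and keepdims = 0, i.e. a full reduction to a single value, by chaining Reshape(1x1x1x-1) -> global Pooling -> Reshape({1}). A usage-level sketch of what should now import, assuming a small ONNX model whose output is such a global ReduceMax ("reduce_max.onnx" is a placeholder file name):

    #include <opencv2/dnn.hpp>
    #include <vector>

    int main()
    {
        // Placeholder model: any net ending in ReduceMax / ReduceMean / ReduceSum
        // with no "axes" attribute and keepdims = 0 should now be importable.
        cv::dnn::Net net = cv::dnn::readNetFromONNX("reduce_max.onnx");

        cv::Mat input(std::vector<int>{1, 3, 4, 4}, CV_32F, cv::Scalar(1.0));
        net.setInput(input);
        cv::Mat out = net.forward();   // single value: the maximum over the whole tensor
        return 0;
    }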
@@ -1427,8 +1454,10 @@ void ONNXImporter::populateNet(Net dstNet)
                     case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break;
                     default: type = blob.type();
                 }
-                blob.convertTo(blob, type);
-                addConstant(layerParams.name, blob, constBlobs, outShapes);
+                Mat dst;
+                blob.convertTo(dst, type);
+                dst.dims = blob.dims;
+                addConstant(layerParams.name, dst, constBlobs, outShapes);
                 continue;
             }
             else
@@ -1477,6 +1506,8 @@ void ONNXImporter::populateNet(Net dstNet)
                 {
                     outShape.erase(outShape.begin() + axis);
                     out.reshape(0, outShape);
-                }
+                } else {
+                    out.dims = 1;
+                }
                 addConstant(layerParams.name, out, constBlobs, outShapes);
                 continue;
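
The two hunks above keep the dimensionality of constant blobs consistent: Mat::convertTo produces a fresh matrix header, so the converted constant gets the original dims written back, and the fallback branch explicitly marks its constant output as 1-D. A standalone illustration of the pattern being relied on (a sketch, not importer code):

    #include <opencv2/core.hpp>

    int main()
    {
        // A 1-D constant as the importer tracks it: a cv::Mat whose public 'dims'
        // field is forced to 1 (cv::Mat normally reports at least 2 dimensions).
        cv::Mat blob(1, 4, CV_64F, cv::Scalar(3.0));
        blob.dims = 1;

        cv::Mat dst;
        blob.convertTo(dst, CV_32S);   // conversion rebuilds the header...
        dst.dims = blob.dims;          // ...so the 1-D marker has to be restored by hand
        return 0;
    }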

modules/dnn/test/test_onnx_importer.cpp

@@ -262,6 +262,11 @@ TEST_P(Test_ONNX_layers, ReduceSum)
     testONNXModels("reduce_sum");
 }
 
+TEST_P(Test_ONNX_layers, ReduceMaxGlobal)
+{
+    testONNXModels("reduce_max");
+}
+
 TEST_P(Test_ONNX_layers, ReduceMean3D)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)