mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
Merge pull request #18096 from l-bat:update_onnx_importer
* Added ReduceSum to ONNX importer * Fix comments * Fix Mul
This commit is contained in:
parent
3b5813c035
commit
ad63d24dba
@ -116,7 +116,6 @@ public:
|
||||
CV_CheckEQ(inputs.size(), (size_t)2, "");
|
||||
numOutput = inputs[1].back();
|
||||
cAxis = inputs[0].size() - 1;
|
||||
CV_CheckEQ(numOutput, inputs[0][cAxis - 1], "");
|
||||
int dims = inputs[0].size();
|
||||
CV_CheckEQ(inputs[1].size(), (size_t)dims, "");
|
||||
CV_CheckGE(dims, 2, "");
|
||||
|
@ -262,6 +262,24 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ExpandSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
ExpandSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int values = addNodeToMatch("");
|
||||
int init = addNodeToMatch("ConstantOfShape", values);
|
||||
int coeff = addNodeToMatch("Constant");
|
||||
int mul = addNodeToMatch("Mul", init, coeff);
|
||||
int shape = addNodeToMatch("Constant");
|
||||
int condition = addNodeToMatch("Equal", shape, mul);
|
||||
int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
|
||||
addNodeToMatch("Expand", input, where);
|
||||
setFusedNode("Expand", input, shape);
|
||||
}
|
||||
};
|
||||
|
||||
class MulCastSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
@ -459,6 +477,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph3>());
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
||||
subgraphs.push_back(makePtr<ExpandSubgraph>());
|
||||
|
||||
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
|
||||
}
|
||||
|
@ -387,26 +387,42 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
layerParams.set("ceil_mode", layerParams.has("pad_mode"));
|
||||
layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
|
||||
}
|
||||
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
|
||||
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
|
||||
layer_type == "ReduceMean" || layer_type == "ReduceSum")
|
||||
{
|
||||
CV_Assert(node_proto.input_size() == 1);
|
||||
layerParams.type = "Pooling";
|
||||
layerParams.set("pool", layer_type == "GlobalMaxPool"? "MAX" : "AVE");
|
||||
String pool;
|
||||
if (layer_type == "GlobalMaxPool")
|
||||
pool = "MAX";
|
||||
else if (layer_type == "ReduceSum")
|
||||
pool = "SUM";
|
||||
else
|
||||
pool = "AVE";
|
||||
layerParams.set("pool", pool);
|
||||
layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
|
||||
|
||||
if (layer_type == "ReduceMean")
|
||||
if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
|
||||
{
|
||||
if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");
|
||||
if (!layerParams.has("axes"))
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
|
||||
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
DictValue axes = layerParams.get("axes");
|
||||
bool keepdims = layerParams.get<int>("keepdims");
|
||||
MatShape targetShape = inpShape;
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
int axis = clamp(axes.get<int>(i), inpShape.size());
|
||||
if (keepdims) {
|
||||
targetShape[axis] = 1;
|
||||
} else {
|
||||
targetShape.erase(targetShape.begin() + axis);
|
||||
}
|
||||
}
|
||||
|
||||
if (inpShape.size() == 3 && axes.size() <= 2)
|
||||
{
|
||||
int axis = axes.get<int>(0);
|
||||
int axis = clamp(axes.get<int>(0), inpShape.size());
|
||||
CV_CheckNE(axis, 0, "");
|
||||
outShapes[layerParams.name] = inpShape;
|
||||
outShapes[layerParams.name][axis] = 1;
|
||||
|
||||
LayerParams reshapeLp;
|
||||
reshapeLp.name = layerParams.name + "/reshape";
|
||||
@ -426,13 +442,12 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
avgLp.name = layerParams.name + "/avg";
|
||||
avgLp.type = "Pooling";
|
||||
CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
|
||||
avgLp.set("pool", "ave");
|
||||
avgLp.set("pool", pool);
|
||||
if (axes.size() == 2)
|
||||
{
|
||||
CV_CheckEQ(axes.get<int>(0), 1, "Unsupported ReduceMean mode");
|
||||
CV_CheckEQ(axes.get<int>(1), 2, "Unsupported ReduceMean mode");
|
||||
CV_CheckEQ(clamp(axes.get<int>(0), inpShape.size()), 1, ("Unsupported " + layer_type + " mode").c_str());
|
||||
CV_CheckEQ(clamp(axes.get<int>(1), inpShape.size()), 2, ("Unsupported " + layer_type + " mode").c_str());
|
||||
avgLp.set("global_pooling", true);
|
||||
outShapes[layerParams.name][axes.get<int>(1)] = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -443,28 +458,33 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
node_proto.set_input(0, reshapeLp.name);
|
||||
node_proto.set_output(0, avgLp.name);
|
||||
addLayer(dstNet, avgLp, node_proto, layer_id, outShapes);
|
||||
|
||||
layerParams.type = "Flatten";
|
||||
layerParams.set("axis", 0);
|
||||
layerParams.set("end_axis", 1);
|
||||
|
||||
node_proto.set_input(0, avgLp.name);
|
||||
node_proto.set_output(0, layerParams.name);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (inpShape.size() != 4 && inpShape.size() != 5)
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");
|
||||
|
||||
CV_Assert(axes.size() <= inpShape.size() - 2);
|
||||
std::vector<int> kernel_size(inpShape.size() - 2, 1);
|
||||
for (int i = 0; i < axes.size(); i++) {
|
||||
int axis = axes.get<int>(i);
|
||||
int axis = clamp(axes.get<int>(i), inpShape.size());
|
||||
CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
|
||||
kernel_size[axis - 2] = inpShape[axis];
|
||||
}
|
||||
layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
|
||||
LayerParams poolLp = layerParams;
|
||||
poolLp.name = layerParams.name + "/avg";
|
||||
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
|
||||
poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
|
||||
|
||||
node_proto.set_output(0, poolLp.name);
|
||||
addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
|
||||
}
|
||||
|
||||
layerParams.type = "Reshape";
|
||||
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
|
||||
|
||||
node_proto.set_input(0, node_proto.output(0));
|
||||
node_proto.set_output(0, layerParams.name);
|
||||
}
|
||||
}
|
||||
else if (layer_type == "Slice")
|
||||
@ -1001,15 +1021,10 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
{
|
||||
Mat inp0 = getBlob(node_proto, constBlobs, 0);
|
||||
Mat inp1 = getBlob(node_proto, constBlobs, 1);
|
||||
if (inp0.size != inp1.size)
|
||||
if (inp0.size != inp1.size && inp1.total() != 1)
|
||||
CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
|
||||
|
||||
Mat out;
|
||||
if (isDiv)
|
||||
divide(inp0, inp1, out);
|
||||
else
|
||||
multiply(inp0, inp1, out);
|
||||
|
||||
Mat out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
|
||||
out = out.reshape(1, inp0.dims, inp0.size);
|
||||
out.dims = inp0.dims; // to workaround dims == 1
|
||||
addConstant(layerParams.name, out, constBlobs, outShapes);
|
||||
@ -1180,9 +1195,45 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
Mat newShapeMat = getBlob(node_proto, constBlobs, 1);
|
||||
MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
|
||||
|
||||
shapeIt = outShapes.find(node_proto.input(0));
|
||||
CV_Assert(shapeIt != outShapes.end());
|
||||
MatShape inpShape = shapeIt->second;
|
||||
MatShape inpShape;
|
||||
bool haveVariables = constBlobs.find(node_proto.input(0)) == constBlobs.end();
|
||||
if (haveVariables)
|
||||
{
|
||||
shapeIt = outShapes.find(node_proto.input(0));
|
||||
CV_Assert(shapeIt != outShapes.end());
|
||||
inpShape = shapeIt->second;
|
||||
}
|
||||
else
|
||||
{
|
||||
inpShape = shape(getBlob(node_proto, constBlobs, 0));
|
||||
}
|
||||
|
||||
String srcName = node_proto.input(0);
|
||||
// Unsqueeze and repeat along new axis
|
||||
if (targetShape.size() == inpShape.size() + 1)
|
||||
{
|
||||
for (int i = 0; i < targetShape.size(); i++)
|
||||
{
|
||||
if (targetShape[i] == -1 && i < inpShape.size())
|
||||
targetShape[i] = inpShape[i];
|
||||
else if (i < inpShape.size() && targetShape[i] != inpShape[i])
|
||||
inpShape.insert(inpShape.begin() + i, 1);
|
||||
}
|
||||
if (haveVariables)
|
||||
{
|
||||
LayerParams reshapeLp;
|
||||
reshapeLp.name = layerParams.name + "/reshape";
|
||||
reshapeLp.type = "Reshape";
|
||||
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
|
||||
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
|
||||
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_input(node_proto.input(0));
|
||||
proto.add_output(reshapeLp.name);
|
||||
addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
|
||||
srcName = reshapeLp.name;
|
||||
}
|
||||
}
|
||||
CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
|
||||
|
||||
std::vector<int> broadcast_axes;
|
||||
@ -1197,6 +1248,19 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
}
|
||||
}
|
||||
|
||||
if (!haveVariables)
|
||||
{
|
||||
if (broadcast_axes.size() != 1)
|
||||
CV_Error(Error::StsNotImplemented, "Expand op doesn't support multiple axes for constant input");
|
||||
|
||||
Mat input = getBlob(node_proto, constBlobs, 0);
|
||||
input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
|
||||
Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
|
||||
output = output.reshape(0, targetShape);
|
||||
addConstant(layerParams.name, output, constBlobs, outShapes);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (broadcast_axes.size() == 2 &&
|
||||
broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
|
||||
{
|
||||
@ -1231,6 +1295,7 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
|
||||
input_names.push_back(copyLP.name);
|
||||
|
||||
node_proto.set_input(0, srcName);
|
||||
node_proto.set_output(0, copyLP.name);
|
||||
addLayer(dstNet, copyLP, node_proto, layer_id, outShapes);
|
||||
}
|
||||
@ -1241,6 +1306,7 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
}
|
||||
layerParams.set("axis", broadcast_axes[0]);
|
||||
layerParams.type = "Concat";
|
||||
node_proto.set_output(0, layerParams.name);
|
||||
}
|
||||
else
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
|
||||
|
@ -257,6 +257,11 @@ TEST_P(Test_ONNX_layers, ReduceMean)
|
||||
testONNXModels("reduce_mean_axis2");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ReduceSum)
|
||||
{
|
||||
testONNXModels("reduce_sum");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ReduceMean3D)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)
|
||||
@ -417,6 +422,7 @@ TEST_P(Test_ONNX_layers, Expand)
|
||||
{
|
||||
testONNXModels("expand_batch");
|
||||
testONNXModels("expand_channels");
|
||||
testONNXModels("expand_neg_batch");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, ExpandHW)
|
||||
|
Loading…
Reference in New Issue
Block a user