Mirror of https://github.com/opencv/opencv.git
Merge pull request #18096 from l-bat:update_onnx_importer
* Added ReduceSum to ONNX importer
* Fix comments
* Fix Mul
parent 3b5813c035
commit ad63d24dba
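In short: ReduceSum is now lowered the same way as ReduceMean, onto a Pooling layer (pool type SUM) followed by a Reshape to the reduced shape, and constant folding for Mul/Div and Expand is fixed along the way. A minimal usage sketch of what this enables (not part of the PR; the model path is a placeholder):

#include <opencv2/core.hpp>
#include <opencv2/dnn.hpp>

int main()
{
    // "reduce_sum.onnx" is a hypothetical model containing a ReduceSum node;
    // after this change it imports instead of failing with StsNotImplemented.
    cv::dnn::Net net = cv::dnn::readNetFromONNX("reduce_sum.onnx");

    cv::Mat inp(std::vector<int>{1, 3, 4, 4}, CV_32F, cv::Scalar(1));
    net.setInput(inp);
    cv::Mat out = net.forward();  // internally runs as SUM Pooling + Reshape
    return 0;
}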
@@ -116,7 +116,6 @@ public:
         CV_CheckEQ(inputs.size(), (size_t)2, "");
         numOutput = inputs[1].back();
         cAxis = inputs[0].size() - 1;
-        CV_CheckEQ(numOutput, inputs[0][cAxis - 1], "");
         int dims = inputs[0].size();
         CV_CheckEQ(inputs[1].size(), (size_t)dims, "");
         CV_CheckGE(dims, 2, "");
@@ -262,6 +262,24 @@ public:
     }
 };
 
+class ExpandSubgraph : public Subgraph
+{
+public:
+    ExpandSubgraph()
+    {
+        int input = addNodeToMatch("");
+        int values = addNodeToMatch("");
+        int init = addNodeToMatch("ConstantOfShape", values);
+        int coeff = addNodeToMatch("Constant");
+        int mul = addNodeToMatch("Mul", init, coeff);
+        int shape = addNodeToMatch("Constant");
+        int condition = addNodeToMatch("Equal", shape, mul);
+        int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
+        addNodeToMatch("Expand", input, where);
+        setFusedNode("Expand", input, shape);
+    }
+};
+
 class MulCastSubgraph : public Subgraph
 {
 public:
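Note: the matched chain (ConstantOfShape feeding Mul, compared against a Constant via Equal, selected through Where into Expand) is the kind of shape-computation pattern some exporters (e.g. PyTorch) emit around a broadcast; the fusion collapses the whole chain into a single Expand that reads its target shape straight from the Constant node.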
@@ -459,6 +477,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
     subgraphs.push_back(makePtr<NormalizeSubgraph3>());
     subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
     subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
+    subgraphs.push_back(makePtr<ExpandSubgraph>());
 
     simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
 }
@@ -387,26 +387,42 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
         }
-        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
+        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
+                layer_type == "ReduceMean" || layer_type == "ReduceSum")
         {
             CV_Assert(node_proto.input_size() == 1);
             layerParams.type = "Pooling";
-            layerParams.set("pool", layer_type == "GlobalMaxPool"? "MAX" : "AVE");
+            String pool;
+            if (layer_type == "GlobalMaxPool")
+                pool = "MAX";
+            else if (layer_type == "ReduceSum")
+                pool = "SUM";
+            else
+                pool = "AVE";
+            layerParams.set("pool", pool);
             layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
-            if (layer_type == "ReduceMean")
+            if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
             {
-                if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
-                    CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");
+                if (!layerParams.has("axes"))
+                    CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
 
                 MatShape inpShape = outShapes[node_proto.input(0)];
                 DictValue axes = layerParams.get("axes");
+                bool keepdims = layerParams.get<int>("keepdims");
+                MatShape targetShape = inpShape;
+                for (int i = 0; i < axes.size(); i++) {
+                    int axis = clamp(axes.get<int>(i), inpShape.size());
+                    if (keepdims) {
+                        targetShape[axis] = 1;
+                    } else {
+                        targetShape.erase(targetShape.begin() + axis);
+                    }
+                }
 
                 if (inpShape.size() == 3 && axes.size() <= 2)
                 {
-                    int axis = axes.get<int>(0);
+                    int axis = clamp(axes.get<int>(0), inpShape.size());
                     CV_CheckNE(axis, 0, "");
-                    outShapes[layerParams.name] = inpShape;
-                    outShapes[layerParams.name][axis] = 1;
 
                     LayerParams reshapeLp;
                     reshapeLp.name = layerParams.name + "/reshape";
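Aside: the targetShape computed above is what the trailing Reshape (added two hunks below) restores after pooling. A standalone sketch of the keepdims semantics with made-up shapes (not importer code; it walks axes back-to-front so erase() keeps pending indices valid):

#include <cassert>
#include <vector>

// Reduced output shape for ONNX Reduce* ops: keepdims=1 pins each reduced
// axis to 1; keepdims=0 drops the axis entirely.
static std::vector<int> reducedShape(std::vector<int> shape,
                                     std::vector<int> axes, bool keepdims)
{
    for (int i = (int)axes.size() - 1; i >= 0; --i)  // back-to-front
    {
        if (keepdims)
            shape[axes[i]] = 1;
        else
            shape.erase(shape.begin() + axes[i]);
    }
    return shape;
}

int main()
{
    assert((reducedShape({2, 3, 4}, {1, 2}, true)  == std::vector<int>{2, 1, 1}));
    assert((reducedShape({2, 3, 4}, {1, 2}, false) == std::vector<int>{2}));
    return 0;
}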
@@ -426,13 +442,12 @@ void ONNXImporter::populateNet(Net dstNet)
                     avgLp.name = layerParams.name + "/avg";
                     avgLp.type = "Pooling";
                     CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
-                    avgLp.set("pool", "ave");
+                    avgLp.set("pool", pool);
                     if (axes.size() == 2)
                     {
-                        CV_CheckEQ(axes.get<int>(0), 1, "Unsupported ReduceMean mode");
-                        CV_CheckEQ(axes.get<int>(1), 2, "Unsupported ReduceMean mode");
+                        CV_CheckEQ(clamp(axes.get<int>(0), inpShape.size()), 1, ("Unsupported " + layer_type + " mode").c_str());
+                        CV_CheckEQ(clamp(axes.get<int>(1), inpShape.size()), 2, ("Unsupported " + layer_type + " mode").c_str());
                         avgLp.set("global_pooling", true);
-                        outShapes[layerParams.name][axes.get<int>(1)] = 1;
                     }
                     else
                     {
@@ -443,28 +458,33 @@ void ONNXImporter::populateNet(Net dstNet)
                     node_proto.set_input(0, reshapeLp.name);
                     node_proto.set_output(0, avgLp.name);
                     addLayer(dstNet, avgLp, node_proto, layer_id, outShapes);
-
-                    layerParams.type = "Flatten";
-                    layerParams.set("axis", 0);
-                    layerParams.set("end_axis", 1);
-
-                    node_proto.set_input(0, avgLp.name);
-                    node_proto.set_output(0, layerParams.name);
                 }
                 else
                 {
                     if (inpShape.size() != 4 && inpShape.size() != 5)
-                        CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");
+                        CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");
 
                     CV_Assert(axes.size() <= inpShape.size() - 2);
                     std::vector<int> kernel_size(inpShape.size() - 2, 1);
                     for (int i = 0; i < axes.size(); i++) {
-                        int axis = axes.get<int>(i);
+                        int axis = clamp(axes.get<int>(i), inpShape.size());
                         CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
                         kernel_size[axis - 2] = inpShape[axis];
                     }
-                    layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
+                    LayerParams poolLp = layerParams;
+                    poolLp.name = layerParams.name + "/avg";
+                    CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
+                    poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
+
+                    node_proto.set_output(0, poolLp.name);
+                    addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
                 }
+
+                layerParams.type = "Reshape";
+                layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
+
+                node_proto.set_input(0, node_proto.output(0));
+                node_proto.set_output(0, layerParams.name);
             }
         }
         else if (layer_type == "Slice")
@@ -1001,15 +1021,10 @@ void ONNXImporter::populateNet(Net dstNet)
         {
             Mat inp0 = getBlob(node_proto, constBlobs, 0);
             Mat inp1 = getBlob(node_proto, constBlobs, 1);
-            if (inp0.size != inp1.size)
+            if (inp0.size != inp1.size && inp1.total() != 1)
                 CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
 
-            Mat out;
-            if (isDiv)
-                divide(inp0, inp1, out);
-            else
-                multiply(inp0, inp1, out);
-
+            Mat out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
             out = out.reshape(1, inp0.dims, inp0.size);
             out.dims = inp0.dims; // to workaround dims == 1
             addConstant(layerParams.name, out, constBlobs, outShapes);
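For reference, the expression forms used in the rewritten Mul/Div branch (a toy sketch with made-up values; the cv::MatExpr form is more compact than the divide()/multiply() calls it replaces):

#include <opencv2/core.hpp>

int main()
{
    cv::Mat a = (cv::Mat_<float>(2, 2) << 1, 2, 3, 4);
    cv::Mat b = (cv::Mat_<float>(2, 2) << 2, 2, 2, 2);

    cv::Mat prod = a.mul(b);  // element-wise, equivalent to multiply(a, b, prod)
    cv::Mat quot = a / b;     // element-wise, equivalent to divide(a, b, quot)

    CV_Assert(prod.at<float>(0, 0) == 2.f && quot.at<float>(1, 1) == 2.f);
    return 0;
}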
@@ -1180,9 +1195,45 @@ void ONNXImporter::populateNet(Net dstNet)
             Mat newShapeMat = getBlob(node_proto, constBlobs, 1);
             MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
 
-            shapeIt = outShapes.find(node_proto.input(0));
-            CV_Assert(shapeIt != outShapes.end());
-            MatShape inpShape = shapeIt->second;
+            MatShape inpShape;
+            bool haveVariables = constBlobs.find(node_proto.input(0)) == constBlobs.end();
+            if (haveVariables)
+            {
+                shapeIt = outShapes.find(node_proto.input(0));
+                CV_Assert(shapeIt != outShapes.end());
+                inpShape = shapeIt->second;
+            }
+            else
+            {
+                inpShape = shape(getBlob(node_proto, constBlobs, 0));
+            }
+
+            String srcName = node_proto.input(0);
+            // Unsqueeze and repeat along new axis
+            if (targetShape.size() == inpShape.size() + 1)
+            {
+                for (int i = 0; i < targetShape.size(); i++)
+                {
+                    if (targetShape[i] == -1 && i < inpShape.size())
+                        targetShape[i] = inpShape[i];
+                    else if (i < inpShape.size() && targetShape[i] != inpShape[i])
+                        inpShape.insert(inpShape.begin() + i, 1);
+                }
+                if (haveVariables)
+                {
+                    LayerParams reshapeLp;
+                    reshapeLp.name = layerParams.name + "/reshape";
+                    reshapeLp.type = "Reshape";
+                    CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
+                    reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
+
+                    opencv_onnx::NodeProto proto;
+                    proto.add_input(node_proto.input(0));
+                    proto.add_output(reshapeLp.name);
+                    addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
+                    srcName = reshapeLp.name;
+                }
+            }
             CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
 
             std::vector<int> broadcast_axes;
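A standalone sketch of the rank-alignment step above (toy shapes, not the importer's buffers): when the Expand target has one extra dimension, a 1 is inserted into the input shape at each mismatch so both ranks agree before broadcasting.

#include <cassert>
#include <vector>

int main()
{
    std::vector<int> inpShape = {3, 4};
    std::vector<int> targetShape = {2, 3, 4};

    // Insert a unit dimension wherever the shapes disagree.
    for (size_t i = 0; i < targetShape.size(); i++)
    {
        if (i < inpShape.size() && targetShape[i] != inpShape[i])
            inpShape.insert(inpShape.begin() + i, 1);
    }
    assert((inpShape == std::vector<int>{1, 3, 4}));
    return 0;
}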
@@ -1197,6 +1248,19 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
             }
 
+            if (!haveVariables)
+            {
+                if (broadcast_axes.size() != 1)
+                    CV_Error(Error::StsNotImplemented, "Expand op doesn't support multiple axes for constant input");
+
+                Mat input = getBlob(node_proto, constBlobs, 0);
+                input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
+                Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
+                output = output.reshape(0, targetShape);
+                addConstant(layerParams.name, output, constBlobs, outShapes);
+                continue;
+            }
+
             if (broadcast_axes.size() == 2 &&
                 broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
             {
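The new constant-input path above folds Expand at import time with cv::repeat; a toy sketch of that call (shapes are made up):

#include <opencv2/core.hpp>

int main()
{
    // Broadcast a 1x3 row to 4x3: repeat 4 times vertically, once horizontally.
    cv::Mat row = (cv::Mat_<float>(1, 3) << 1, 2, 3);
    cv::Mat tiled = cv::repeat(row, 4, 1);
    CV_Assert(tiled.rows == 4 && tiled.cols == 3);
    return 0;
}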
@@ -1231,6 +1295,7 @@ void ONNXImporter::populateNet(Net dstNet)
                     CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
                     input_names.push_back(copyLP.name);
 
+                    node_proto.set_input(0, srcName);
                     node_proto.set_output(0, copyLP.name);
                     addLayer(dstNet, copyLP, node_proto, layer_id, outShapes);
                 }
@@ -1241,6 +1306,7 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
                 layerParams.set("axis", broadcast_axes[0]);
                 layerParams.type = "Concat";
+                node_proto.set_output(0, layerParams.name);
             }
             else
                 CV_Error(Error::StsNotImplemented, "Unsupported Expand op");
@@ -257,6 +257,11 @@ TEST_P(Test_ONNX_layers, ReduceMean)
     testONNXModels("reduce_mean_axis2");
 }
 
+TEST_P(Test_ONNX_layers, ReduceSum)
+{
+    testONNXModels("reduce_sum");
+}
+
 TEST_P(Test_ONNX_layers, ReduceMean3D)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)
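Note: testONNXModels() resolves its argument against the opencv_extra test data (presumably dnn/onnx/models/reduce_sum.onnx plus the matching reference input/output blobs), so the new ReduceSum and expand_neg_batch cases assume those files land in the companion opencv_extra change.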
@@ -417,6 +422,7 @@ TEST_P(Test_ONNX_layers, Expand)
 {
     testONNXModels("expand_batch");
     testONNXModels("expand_channels");
+    testONNXModels("expand_neg_batch");
 }
 
 TEST_P(Test_ONNX_layers, ExpandHW)