mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 14:36:36 +08:00
Merge pull request #17233 from l-bat:onnx_bn
* Added ONNX BatchNorm subgraph * Move removing constant inputs to addConstantNodesForInitializers * Added initializers to ONNXGraphWrapper
This commit is contained in:
parent
1bf353b876
commit
79f8b7fd73
@ -61,27 +61,28 @@ public:
|
||||
ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net)
|
||||
{
|
||||
numInputs = net.input_size();
|
||||
numInitializers = net.initializer_size();
|
||||
}
|
||||
|
||||
virtual Ptr<ImportNodeWrapper> getNode(int idx) const CV_OVERRIDE
|
||||
{
|
||||
opencv_onnx::NodeProto* node = 0;
|
||||
if (idx >= numInputs)
|
||||
node = net.mutable_node(idx - numInputs);
|
||||
if (idx >= numInputs + numInitializers)
|
||||
node = net.mutable_node(idx - numInputs - numInitializers);
|
||||
return makePtr<ONNXNodeWrapper>(node);
|
||||
}
|
||||
|
||||
virtual int getNumNodes() const CV_OVERRIDE
|
||||
{
|
||||
return numInputs + net.node_size();
|
||||
return numInputs + numInitializers + net.node_size();
|
||||
}
|
||||
|
||||
virtual int getNumOutputs(int nodeId) const CV_OVERRIDE
|
||||
{
|
||||
if (nodeId < numInputs)
|
||||
if (nodeId < numInputs + numInitializers)
|
||||
return 1;
|
||||
else
|
||||
return net.node(nodeId - numInputs).output_size();
|
||||
return net.node(nodeId - numInputs - numInitializers).output_size();
|
||||
}
|
||||
|
||||
virtual std::string getOutputName(int nodeId, int outId) const CV_OVERRIDE
|
||||
@ -89,18 +90,20 @@ public:
|
||||
CV_Assert(outId < getNumOutputs(nodeId));
|
||||
if (nodeId < numInputs)
|
||||
return net.input(nodeId).name();
|
||||
else if (nodeId < numInputs + numInitializers)
|
||||
return net.initializer(nodeId - numInputs).name();
|
||||
else
|
||||
return net.node(nodeId - numInputs).output(outId);
|
||||
return net.node(nodeId - numInputs - numInitializers).output(outId);
|
||||
}
|
||||
|
||||
virtual void removeNode(int idx) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(idx >= numInputs);
|
||||
net.mutable_node()->DeleteSubrange(idx - numInputs, 1);
|
||||
CV_Assert(idx >= numInputs + numInitializers);
|
||||
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
|
||||
}
|
||||
|
||||
private:
|
||||
int numInputs;
|
||||
int numInputs, numInitializers;
|
||||
opencv_onnx::GraphProto& net;
|
||||
};
|
||||
|
||||
@ -382,33 +385,63 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class BatchNormalizationSubgraph : public Subgraph
|
||||
class BatchNormalizationSubgraphBase : public Subgraph
|
||||
{
|
||||
public:
|
||||
BatchNormalizationSubgraph()
|
||||
BatchNormalizationSubgraphBase()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int data1 = addNodeToMatch("Constant");
|
||||
int data2 = addNodeToMatch("Constant");
|
||||
int data3 = addNodeToMatch("Constant");
|
||||
int data4 = addNodeToMatch("Constant");
|
||||
int shape1 = addNodeToMatch("Constant");
|
||||
int reshape1 = addNodeToMatch("Reshape", data1, shape1);
|
||||
int shape2 = addNodeToMatch("Constant");
|
||||
int reshape2 = addNodeToMatch("Reshape", data2, shape2);
|
||||
input = addNodeToMatch("");
|
||||
var = addNodeToMatch("");
|
||||
mean = addNodeToMatch("");
|
||||
weight = addNodeToMatch("");
|
||||
bias = addNodeToMatch("");
|
||||
A = addNodeToMatch("");
|
||||
shape1 = addNodeToMatch("");
|
||||
shape2 = addNodeToMatch("");
|
||||
}
|
||||
protected:
|
||||
int input, var, mean, weight, bias, A, shape1, shape2;
|
||||
};
|
||||
|
||||
class BatchNormalizationSubgraph1 : public BatchNormalizationSubgraphBase
|
||||
{
|
||||
public:
|
||||
BatchNormalizationSubgraph1()
|
||||
{
|
||||
int reshape1 = addNodeToMatch("Reshape", weight, shape1);
|
||||
int reshape2 = addNodeToMatch("Reshape", bias, shape2);
|
||||
int shape3 = addNodeToMatch("Constant");
|
||||
int reshape3 = addNodeToMatch("Reshape", data3, shape3);
|
||||
int reshape3 = addNodeToMatch("Reshape", var, shape3);
|
||||
int shape4 = addNodeToMatch("Constant");
|
||||
int reshape4 = addNodeToMatch("Reshape", data4, shape4);
|
||||
int reshape4 = addNodeToMatch("Reshape", mean, shape4);
|
||||
int sqrtNode = addNodeToMatch("Sqrt", reshape3);
|
||||
int A = addNodeToMatch("Constant");
|
||||
int divNode = addNodeToMatch("Div", A, sqrtNode);
|
||||
int mul1 = addNodeToMatch("Mul", reshape1, divNode);
|
||||
int mul2 = addNodeToMatch("Mul", reshape4, mul1);
|
||||
int sub = addNodeToMatch("Sub", reshape2, mul2);
|
||||
int mul3 = addNodeToMatch("Mul", input, mul1);
|
||||
addNodeToMatch("Add", mul3, sub);
|
||||
setFusedNode("BatchNormalization", input, data1, data2, data4 ,data3);
|
||||
setFusedNode("BatchNormalization", input, weight, bias, mean, var);
|
||||
}
|
||||
};
|
||||
|
||||
class BatchNormalizationSubgraph2 : public BatchNormalizationSubgraphBase
|
||||
{
|
||||
public:
|
||||
BatchNormalizationSubgraph2()
|
||||
{
|
||||
int sqrtNode = addNodeToMatch("Sqrt", var);
|
||||
int divNode = addNodeToMatch("Div", A, sqrtNode);
|
||||
int mul1 = addNodeToMatch("Mul", weight, divNode);
|
||||
int reshape2 = addNodeToMatch("Reshape", mul1, shape2);
|
||||
|
||||
int mulMean = addNodeToMatch("Mul", mean, mul1);
|
||||
int sub = addNodeToMatch("Sub", bias, mulMean);
|
||||
int reshape1 = addNodeToMatch("Reshape", sub, shape1);
|
||||
|
||||
int mulInput = addNodeToMatch("Mul", input, reshape2);
|
||||
addNodeToMatch("Add", mulInput, reshape1);
|
||||
setFusedNode("BatchNormalization", input, weight, bias, mean, var);
|
||||
}
|
||||
};
|
||||
|
||||
@ -424,7 +457,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph3>());
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph>());
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
||||
|
||||
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
|
||||
}
|
||||
|
@ -309,30 +309,11 @@ static void addConstant(const std::string& name,
|
||||
outShapes.insert(std::make_pair(name, shape(blob)));
|
||||
}
|
||||
|
||||
void addConstantNodesForInitializers(opencv_onnx::GraphProto& graph_proto)
|
||||
{
|
||||
int num_initializers = graph_proto.initializer_size();
|
||||
for (int id = 0; id < num_initializers; id++)
|
||||
{
|
||||
opencv_onnx::TensorProto initializer = graph_proto.initializer(id);
|
||||
opencv_onnx::NodeProto* constant_node = graph_proto.add_node();
|
||||
constant_node->set_op_type("Constant");
|
||||
constant_node->set_name(initializer.name());
|
||||
constant_node->add_output(initializer.name());
|
||||
opencv_onnx::AttributeProto* value = constant_node->add_attribute();
|
||||
opencv_onnx::TensorProto* tensor = initializer.New();
|
||||
tensor->CopyFrom(initializer);
|
||||
releaseONNXTensor(initializer);
|
||||
value->set_allocated_t(tensor);
|
||||
}
|
||||
}
|
||||
|
||||
void ONNXImporter::populateNet(Net dstNet)
|
||||
{
|
||||
CV_Assert(model_proto.has_graph());
|
||||
opencv_onnx::GraphProto graph_proto = model_proto.graph();
|
||||
|
||||
addConstantNodesForInitializers(graph_proto);
|
||||
simplifySubgraphs(graph_proto);
|
||||
|
||||
std::map<std::string, Mat> constBlobs = getGraphTensors(graph_proto);
|
||||
|
@ -306,6 +306,13 @@ TEST_P(Test_ONNX_layers, BatchNormalizationUnfused)
|
||||
testONNXModels("frozenBatchNorm2d");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, BatchNormalizationSubgraph)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
|
||||
testONNXModels("batch_norm_subgraph");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Transpose)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||
|
Loading…
Reference in New Issue
Block a user