Added ONNX NormalizeL2 subgraph

This commit is contained in:
Liubov Batanina 2021-02-01 12:38:33 +03:00
parent 0f968e3b6d
commit 68eb54dc13
2 changed files with 41 additions and 0 deletions

View File

@ -249,6 +249,40 @@ public:
}
};
class NormalizeSubgraph4 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph4() : NormalizeSubgraphBase(1)
{
int input = addNodeToMatch("");
int mul = addNodeToMatch("Mul", input, input);
int sum = addNodeToMatch("ReduceSum", mul);
int eps = addNodeToMatch("");
int max = addNodeToMatch("Max", sum, eps);
int sqrt = addNodeToMatch("Sqrt", max);
int reciprocal = addNodeToMatch("Reciprocal", sqrt);
addNodeToMatch("Mul", input, reciprocal);
setFusedNode("Normalize", input);
}
};
class NormalizeSubgraph5 : public NormalizeSubgraphBase
{
public:
NormalizeSubgraph5() : NormalizeSubgraphBase(1)
{
int input = addNodeToMatch("");
int mul = addNodeToMatch("Mul", input, input);
int sum = addNodeToMatch("ReduceSum", mul);
int clip = addNodeToMatch("Clip", sum);
int sqrt = addNodeToMatch("Sqrt", clip);
int one = addNodeToMatch("Constant");
int div = addNodeToMatch("Div", one, sqrt);
addNodeToMatch("Mul", input, div);
setFusedNode("Normalize", input);
}
};
class GatherCastSubgraph : public Subgraph
{
public:
@ -526,6 +560,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>());
subgraphs.push_back(makePtr<MishSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}

View File

@ -403,6 +403,11 @@ TEST_P(Test_ONNX_layers, BatchNormalizationSubgraph)
testONNXModels("batch_norm_subgraph");
}
TEST_P(Test_ONNX_layers, NormalizeFusionSubgraph)
{
testONNXModels("normalize_fusion");
}
TEST_P(Test_ONNX_layers, Transpose)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)