mirror of
https://github.com/opencv/opencv.git
synced 2025-07-25 22:57:53 +08:00
Added ONNX NormalizeL2 subgraph
This commit is contained in:
parent
0f968e3b6d
commit
68eb54dc13
@ -249,6 +249,40 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class NormalizeSubgraph4 : public NormalizeSubgraphBase
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NormalizeSubgraph4() : NormalizeSubgraphBase(1)
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int mul = addNodeToMatch("Mul", input, input);
|
||||||
|
int sum = addNodeToMatch("ReduceSum", mul);
|
||||||
|
int eps = addNodeToMatch("");
|
||||||
|
int max = addNodeToMatch("Max", sum, eps);
|
||||||
|
int sqrt = addNodeToMatch("Sqrt", max);
|
||||||
|
int reciprocal = addNodeToMatch("Reciprocal", sqrt);
|
||||||
|
addNodeToMatch("Mul", input, reciprocal);
|
||||||
|
setFusedNode("Normalize", input);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class NormalizeSubgraph5 : public NormalizeSubgraphBase
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
NormalizeSubgraph5() : NormalizeSubgraphBase(1)
|
||||||
|
{
|
||||||
|
int input = addNodeToMatch("");
|
||||||
|
int mul = addNodeToMatch("Mul", input, input);
|
||||||
|
int sum = addNodeToMatch("ReduceSum", mul);
|
||||||
|
int clip = addNodeToMatch("Clip", sum);
|
||||||
|
int sqrt = addNodeToMatch("Sqrt", clip);
|
||||||
|
int one = addNodeToMatch("Constant");
|
||||||
|
int div = addNodeToMatch("Div", one, sqrt);
|
||||||
|
addNodeToMatch("Mul", input, div);
|
||||||
|
setFusedNode("Normalize", input);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
class GatherCastSubgraph : public Subgraph
|
class GatherCastSubgraph : public Subgraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -526,6 +560,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
|||||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
||||||
subgraphs.push_back(makePtr<ExpandSubgraph>());
|
subgraphs.push_back(makePtr<ExpandSubgraph>());
|
||||||
subgraphs.push_back(makePtr<MishSubgraph>());
|
subgraphs.push_back(makePtr<MishSubgraph>());
|
||||||
|
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
|
||||||
|
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
|
||||||
|
|
||||||
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
|
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
|
||||||
}
|
}
|
||||||
|
@ -403,6 +403,11 @@ TEST_P(Test_ONNX_layers, BatchNormalizationSubgraph)
|
|||||||
testONNXModels("batch_norm_subgraph");
|
testONNXModels("batch_norm_subgraph");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_ONNX_layers, NormalizeFusionSubgraph)
|
||||||
|
{
|
||||||
|
testONNXModels("normalize_fusion");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(Test_ONNX_layers, Transpose)
|
TEST_P(Test_ONNX_layers, Transpose)
|
||||||
{
|
{
|
||||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||||
|
Loading…
Reference in New Issue
Block a user