replace new mish impl with softplus

This commit is contained in:
Zihao Mu 2022-07-28 13:19:06 +08:00
parent 3c5377ca1b
commit 57545653b1

View File

@ -531,35 +531,32 @@ public:
} }
}; };
class MishSubgraph2 : public Subgraph // softplus(x) = log(exp(x) + 1)
class SoftplusSubgraph: public Subgraph
{ {
public: public:
MishSubgraph2() SoftplusSubgraph()
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int exp = addNodeToMatch("Exp", input); int exp = addNodeToMatch("Exp", input);
int addVal = addNodeToMatch(""); int addVal = addNodeToMatch("");
int add = addNodeToMatch("Add", addVal, exp); int add = addNodeToMatch("Add", addVal, exp);
int log = addNodeToMatch("Log", add); addNodeToMatch("Log", add);
int tanh = addNodeToMatch("Tanh", log); setFusedNode("Softplus", input);
addNodeToMatch("Mul", input, tanh);
setFusedNode("Mish", input);
} }
}; };
class MishSubgraph3 : public Subgraph class SoftplusSubgraph2: public Subgraph
{ {
public: public:
MishSubgraph3() SoftplusSubgraph2()
{ {
int input = addNodeToMatch(""); int input = addNodeToMatch("");
int exp = addNodeToMatch("Exp", input); int exp = addNodeToMatch("Exp", input);
int addVal = addNodeToMatch(""); int addVal = addNodeToMatch("");
int add = addNodeToMatch("Add", exp, addVal); int add = addNodeToMatch("Add", exp, addVal);
int log = addNodeToMatch("Log", add); addNodeToMatch("Log", add);
int tanh = addNodeToMatch("Tanh", log); setFusedNode("Softplus", input);
addNodeToMatch("Mul", input, tanh);
setFusedNode("Mish", input);
} }
}; };
@ -766,9 +763,9 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>()); subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>()); subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
subgraphs.push_back(makePtr<ExpandSubgraph>()); subgraphs.push_back(makePtr<ExpandSubgraph>());
subgraphs.push_back(makePtr<SoftplusSubgraph>());
subgraphs.push_back(makePtr<SoftplusSubgraph2>());
subgraphs.push_back(makePtr<MishSubgraph>()); subgraphs.push_back(makePtr<MishSubgraph>());
subgraphs.push_back(makePtr<MishSubgraph2>());
subgraphs.push_back(makePtr<MishSubgraph3>());
subgraphs.push_back(makePtr<NormalizeSubgraph4>()); subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>()); subgraphs.push_back(makePtr<NormalizeSubgraph5>());