diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 5aad1c135c..091d2d4ae9 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -531,35 +531,32 @@ public: } }; -class MishSubgraph2 : public Subgraph +// softplus(x) = log(exp(x) + 1) +class SoftplusSubgraph: public Subgraph { public: - MishSubgraph2() + SoftplusSubgraph() { int input = addNodeToMatch(""); int exp = addNodeToMatch("Exp", input); int addVal = addNodeToMatch(""); int add = addNodeToMatch("Add", addVal, exp); - int log = addNodeToMatch("Log", add); - int tanh = addNodeToMatch("Tanh", log); - addNodeToMatch("Mul", input, tanh); - setFusedNode("Mish", input); + addNodeToMatch("Log", add); + setFusedNode("Softplus", input); } }; -class MishSubgraph3 : public Subgraph +class SoftplusSubgraph2: public Subgraph { public: - MishSubgraph3() + SoftplusSubgraph2() { int input = addNodeToMatch(""); int exp = addNodeToMatch("Exp", input); int addVal = addNodeToMatch(""); int add = addNodeToMatch("Add", exp, addVal); - int log = addNodeToMatch("Log", add); - int tanh = addNodeToMatch("Tanh", log); - addNodeToMatch("Mul", input, tanh); - setFusedNode("Mish", input); + addNodeToMatch("Log", add); + setFusedNode("Softplus", input); } }; @@ -766,9 +763,9 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); - subgraphs.push_back(makePtr()); - subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr());