From d3f9ad11459bfc63f04246771dfc42f6d275700a Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Sat, 28 Mar 2020 18:53:57 +0300 Subject: [PATCH] Enable ONNX SSD from https://github.com/amdegroot/ssd.pytorch --- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 54 ++++++++++++++----- modules/dnn/src/onnx/onnx_importer.cpp | 19 +++++++ modules/dnn/test/test_onnx_importer.cpp | 1 + 3 files changed, 60 insertions(+), 14 deletions(-) diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index ff474224cc..bf992feb2c 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -154,16 +154,10 @@ private: int axis; }; -class NormalizeSubgraph1 : public Subgraph +class NormalizeSubgraphBase : public Subgraph { public: - NormalizeSubgraph1() : axis(1) - { - input = addNodeToMatch(""); - norm = addNodeToMatch("ReduceL2", input); - addNodeToMatch("Div", input, norm); - setFusedNode("Normalize", input); - } + NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {} virtual bool match(const Ptr& net, int nodeId, std::vector& matchedNodesIds, @@ -171,7 +165,7 @@ public: { if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) { - Ptr norm = net->getNode(matchedNodesIds[0]); + Ptr norm = net->getNode(matchedNodesIds[normNodeOrder]); opencv_onnx::NodeProto* node = norm.dynamicCast()->node; for (int i = 0; i < node->attribute_size(); i++) @@ -204,20 +198,51 @@ public: } protected: - int input, norm; - int axis; + int axis, normNodeOrder; }; - -class NormalizeSubgraph2 : public NormalizeSubgraph1 +class NormalizeSubgraph1 : public NormalizeSubgraphBase { public: - NormalizeSubgraph2() : NormalizeSubgraph1() + NormalizeSubgraph1() { + int input = addNodeToMatch(""); + int norm = addNodeToMatch("ReduceL2", input); + addNodeToMatch("Div", input, norm); + setFusedNode("Normalize", input); + } +}; + +class NormalizeSubgraph2 : public NormalizeSubgraphBase +{ +public: + NormalizeSubgraph2() + { + int input = addNodeToMatch(""); + int norm = addNodeToMatch("ReduceL2", input); int clip = addNodeToMatch("Clip", norm); int shape = addNodeToMatch("Shape", input); int expand = addNodeToMatch("Expand", clip, shape); addNodeToMatch("Div", input, expand); + setFusedNode("Normalize", input); + } +}; + +class NormalizeSubgraph3 : public NormalizeSubgraphBase +{ +public: + NormalizeSubgraph3() : NormalizeSubgraphBase(1) + { + int input = addNodeToMatch(""); + int power = addNodeToMatch("Constant"); + int squared = addNodeToMatch("Pow", input, power); + int sum = addNodeToMatch("ReduceSum", squared); + int sqrtNode = addNodeToMatch("Sqrt", sum); + int eps = addNodeToMatch("Constant"); + int add = addNodeToMatch("Add", sqrtNode, eps); + + addNodeToMatch("Div", input, add); + setFusedNode("Normalize", input); } }; @@ -368,6 +393,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); simplifySubgraphs(Ptr(new ONNXGraphWrapper(net)), subgraphs); } diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 9c60dcaae9..716c9ad39c 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1457,6 +1457,25 @@ void ONNXImporter::populateNet(Net dstNet) layerParams.type = "Softmax"; layerParams.set("log_softmax", layer_type == "LogSoftmax"); } + else if (layer_type == "DetectionOutput") + { + CV_CheckEQ(node_proto.input_size(), 3, ""); + if (constBlobs.find(node_proto.input(2)) != constBlobs.end()) + { + Mat priors = getBlob(node_proto, constBlobs, 2); + + LayerParams constParams; + constParams.name = layerParams.name + "/priors"; + constParams.type = "Const"; + constParams.blobs.push_back(priors); + + opencv_onnx::NodeProto priorsProto; + priorsProto.add_output(constParams.name); + addLayer(dstNet, constParams, priorsProto, layer_id, outShapes); + + node_proto.set_input(2, constParams.name); + } + } else { for (int j = 0; j < node_proto.input_size(); j++) { diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 6f36f8d3d1..6e47a86631 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -440,6 +440,7 @@ TEST_P(Test_ONNX_layers, ReduceL2) { testONNXModels("reduceL2"); testONNXModels("reduceL2_subgraph"); + testONNXModels("reduceL2_subgraph_2"); } TEST_P(Test_ONNX_layers, Split)