From e18d5e94c731dde931a802631409a303e7114ba1 Mon Sep 17 00:00:00 2001 From: "ashishiva3@gmail.com" Date: Sun, 1 Mar 2020 15:09:15 +0530 Subject: [PATCH] Gather-Cast, Mul-Cast fusion --- modules/dnn/src/graph_simplifier.cpp | 5 +- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 52 ++++++++++++++----- modules/dnn/test/test_onnx_importer.cpp | 1 + 3 files changed, 43 insertions(+), 15 deletions(-) diff --git a/modules/dnn/src/graph_simplifier.cpp b/modules/dnn/src/graph_simplifier.cpp index c5073d8a01..166564c215 100644 --- a/modules/dnn/src/graph_simplifier.cpp +++ b/modules/dnn/src/graph_simplifier.cpp @@ -194,15 +194,14 @@ void simplifySubgraphs(const Ptr& net, { int numNodes = net->getNumNodes(); std::vector matchedNodesIds, targetNodesIds; - for (int i = 0; i < numNodes; ++i) + for (int j = 0; j < patterns.size(); ++j) { - for (int j = 0; j < patterns.size(); ++j) + for (int i = 0; i < numNodes; ++i) { if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds)) { patterns[j]->replace(net, matchedNodesIds, targetNodesIds); numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added. - break; } } } diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 41a768d23c..fe96927840 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -154,6 +154,32 @@ private: int axis; }; +class GatherCastSubgraph : public Subgraph +{ +public: + GatherCastSubgraph() + { + int input = addNodeToMatch(""); + int index = addNodeToMatch("Constant"); + int gather = addNodeToMatch("Gather", input, index); + addNodeToMatch("Cast", gather); + setFusedNode("Gather", input, index); + } +}; + +class MulCastSubgraph : public Subgraph +{ +public: + MulCastSubgraph() + { + int input = addNodeToMatch(""); + int scaleNode = addNodeToMatch("Constant"); + int mul = addNodeToMatch("Mul", input, scaleNode); + addNodeToMatch("Cast", mul); + setFusedNode("Mul", input, scaleNode); + } +}; + class ExtractScalesSubgraph : public Subgraph { public: @@ -164,20 +190,16 @@ public: int indexH = addNodeToMatch("Constant"); int shape1 = addNodeToMatch("Shape", input); int gather1 = addNodeToMatch("Gather", shape1, indexH); - int castG1 = addNodeToMatch("Cast", gather1); scaleHNode = addNodeToMatch("Constant"); - int mul1 = addNodeToMatch("Mul", castG1, scaleHNode); - int castM1 = addNodeToMatch("Cast", mul1); - int floor1 = addNodeToMatch("Floor", castM1); + int mul1 = addNodeToMatch("Mul", gather1, scaleHNode); + int floor1 = addNodeToMatch("Floor", mul1); int indexW = addNodeToMatch("Constant"); int shape2 = addNodeToMatch("Shape", input); int gather2 = addNodeToMatch("Gather", shape2, indexW); - int castG2 = addNodeToMatch("Cast", gather2); scaleWNode = addNodeToMatch("Constant"); - int mul2 = addNodeToMatch("Mul", castG2, scaleWNode); - int castM2 = addNodeToMatch("Cast", mul2); - int floor2 = addNodeToMatch("Floor", castM2); + int mul2 = addNodeToMatch("Mul", gather2, scaleWNode); + int floor2 = addNodeToMatch("Floor", mul2); int unsqueeze1 = addNodeToMatch("Unsqueeze", floor1); int unsqueeze2 = addNodeToMatch("Unsqueeze", floor2); @@ -190,19 +212,23 @@ public: { opencv_onnx::NodeProto* constant_node = inputs[1].dynamicCast()->node; opencv_onnx::TensorProto tensor_proto = constant_node->attribute(0).t(); - float scaleW = getMatFromTensor(tensor_proto).at(0); + Mat scaleW = getMatFromTensor(tensor_proto); + CV_Assert(scaleW.total() == 1); + scaleW.convertTo(scaleW, CV_32F); constant_node = inputs[2].dynamicCast()->node; tensor_proto = constant_node->attribute(0).t(); - float scaleH = getMatFromTensor(tensor_proto).at(0); + Mat scaleH = getMatFromTensor(tensor_proto); + CV_Assert(scaleH.total() == 1); + scaleH.convertTo(scaleH, CV_32F); opencv_onnx::NodeProto* node = fusedNode.dynamicCast()->node; opencv_onnx::AttributeProto* attrH = node->add_attribute(); attrH->set_name("height_scale"); - attrH->set_i(scaleH); + attrH->set_i(scaleH.at(0)); opencv_onnx::AttributeProto* attrW = node->add_attribute(); attrW->set_name("width_scale"); - attrW->set_i(scaleW); + attrW->set_i(scaleW.at(0)); node->mutable_input()->DeleteSubrange(1, 2); // Remove two last inputs } @@ -267,6 +293,8 @@ public: void simplifySubgraphs(opencv_onnx::GraphProto& net) { std::vector > subgraphs; + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index d037263a15..bb7cba1180 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -320,6 +320,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused) { if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); + testONNXModels("upsample_unfused_torch1.2"); testONNXModels("upsample_unfused_opset9_torch1.4"); testONNXModels("resize_nearest_unfused_opset11_torch1.4"); testONNXModels("resize_nearest_unfused_opset11_torch1.3");