From fa56623458e4e582f16b216448e21bb9e664a9ec Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 3 Nov 2023 12:34:09 +0300 Subject: [PATCH] Merge pull request #24463 from dkurt:dnn_shared_nodes_fusion DNN graph fusion with shared nodes #24463 ### Pull Request Readiness Checklist For now, nodes from matched pattern are removed during the matching process so if nodes are used in similar subgraph, they cannot be found. required for https://github.com/opencv/opencv/pull/24397 **Merge with extra**: https://github.com/opencv/opencv_extra/pull/1115 A part from [model_name ](https://github.com/onnx/models/blob/main/vision/object_detection_segmentation/fcn/model/fcn-resnet101-11.onnx) with two Resize subgraphs with shared nodes: ![image](https://github.com/opencv/opencv/assets/25801568/611d89d9-12fb-4add-9218-13b10d2c086a) See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake --- modules/dnn/src/graph_simplifier.cpp | 50 +++++++++++++++++-- .../dnn/src/onnx/onnx_graph_simplifier.cpp | 28 +++++++++++ modules/dnn/test/test_onnx_importer.cpp | 10 +++- 3 files changed, 82 insertions(+), 6 deletions(-) diff --git a/modules/dnn/src/graph_simplifier.cpp b/modules/dnn/src/graph_simplifier.cpp index e58e0e38e8..e1b6d6df40 100644 --- a/modules/dnn/src/graph_simplifier.cpp +++ b/modules/dnn/src/graph_simplifier.cpp @@ -165,10 +165,7 @@ void Subgraph::replace(const Ptr& net, const std::vector node = net->getNode(matchedNodesIds.back()); - for (int i = matchedNodesIds.size() - 2; i >= 0; --i) - net->removeNode(matchedNodesIds[i]); // Modify the last node to be a fused one. node->setType(fusedNodeOp); @@ -191,6 +188,7 @@ void simplifySubgraphs(const Ptr& net, { int numNodes = net->getNumNodes(); std::vector matchedNodesIds, targetNodesIds; + std::vector nodesToRemove; for (int j = 0; j < patterns.size(); ++j) { for (int i = 0; i < numNodes; ++i) @@ -198,10 +196,54 @@ void simplifySubgraphs(const Ptr& net, if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds)) { patterns[j]->replace(net, matchedNodesIds, targetNodesIds); - numNodes -= matchedNodesIds.size() - 1; // #matchedNodes removed and one added. + // Remove matched nodes except the last one. + nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1); } } } + + if (nodesToRemove.empty()) + return; + + // Collect reference counts for every node + std::vector refcounts(net->getNumNodes(), 0); + std::map nodeIds; + + // Register node outputs. + // Every usage of one of the node's outputs should be counted. + for (int nodeId = 0; nodeId < refcounts.size(); ++nodeId) { + for (int i = 0; i < net->getNumOutputs(nodeId); ++i) { + std::string name = net->getOutputName(nodeId, i); + nodeIds[name] = nodeId; + } + } + + for (int nodeId = 0; nodeId < refcounts.size(); ++nodeId) { + // Increase counters for node's inputs + auto node = net->getNode(nodeId); + for (int i = 0; i < node->getNumInputs(); ++i) { + std::string inpName = node->getInputName(i); + if (inpName.empty()) + continue; + CV_Assert(nodeIds.find(inpName) != nodeIds.end()); + refcounts[nodeIds[inpName]] += 1; + } + } + + // Remove all fused nodes. Indices expected to be in descending order. + std::sort(nodesToRemove.begin(), nodesToRemove.end(), [](int a, int b) { return a > b; }); + for (int nodeId : nodesToRemove) { + if (refcounts[nodeId] == 0) { + // Decrease references to node's inputs and remove node itself + auto node = net->getNode(nodeId); + for (int i = 0; i < node->getNumInputs(); ++i) { + std::string inpName = node->getInputName(i); + refcounts[nodeIds[inpName]] -= 1; + } + net->removeNode(nodeId); + refcounts[nodeId] = -1; // Same node cannot be removed twice + } + } } }} // namespace cv::dnn diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index a43815dbe4..15f79c8769 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -1136,6 +1136,33 @@ public: } }; +class ResizeSubgraph3 : public Subgraph +{ +public: + ResizeSubgraph3() : Subgraph() + { + int shapeSrc = addNodeToMatch(""); + int input = addNodeToMatch(""); + + int shape_h = addNodeToMatch("Shape", shapeSrc); + int shape_w = addNodeToMatch("Shape", shapeSrc); + int gather_h = addNodeToMatch("Gather", shape_h, addNodeToMatch("Constant")); + int gather_w = addNodeToMatch("Gather", shape_w, addNodeToMatch("Constant")); + int unsqueeze_h = addNodeToMatch("Unsqueeze", gather_h); + int unsqueeze_w = addNodeToMatch("Unsqueeze", gather_w); + int concat1 = addNodeToMatch("Concat", unsqueeze_h, unsqueeze_w); + int cast = addNodeToMatch("Cast", concat1); + + int shape2 = addNodeToMatch("Shape", input); + int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant")); + int concat2 = addNodeToMatch("Concat", slice, cast); + addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2); + + setFusedNode("Upsample", input, shapeSrc); + } +}; + + class BatchNormalizationSubgraphBase : public Subgraph { public: @@ -1207,6 +1234,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index b7e4e73cbc..cea4ffd739 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -54,7 +54,8 @@ public: void testONNXModels(const String& basename, const Extension ext = npy, double l1 = 0, double lInf = 0, const bool useSoftmax = false, - bool checkNoFallbacks = true, int numInps = 1) + bool checkNoFallbacks = true, int numInps = 1, + bool testShapes = true) { String onnxmodel = _tf("models/" + basename + ".onnx", required); std::vector inps(numInps); @@ -76,7 +77,8 @@ public: Net net = readNetFromONNX(onnxmodel); ASSERT_FALSE(net.empty()); - testInputShapes(net, inps); + if (testShapes) + testInputShapes(net, inps); net.setPreferableBackend(backend); net.setPreferableTarget(target); @@ -248,6 +250,10 @@ TEST_P(Test_ONNX_layers, Gather_shared_indices) { testONNXModels("gather_shared_indices", npy, 0, 0, false, false, 1); } +TEST_P(Test_ONNX_layers, Two_resizes_with_shared_subgraphs) { + testONNXModels("two_resizes_with_shared_subgraphs", npy, 0, 0, false, false, 3, /*testShapes*/ false); +} + TEST_P(Test_ONNX_layers, Convolution3D) { if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16)