diff --git a/modules/dnn/src/graph_simplifier.cpp b/modules/dnn/src/graph_simplifier.cpp
index e58e0e38e8..e1b6d6df40 100644
--- a/modules/dnn/src/graph_simplifier.cpp
+++ b/modules/dnn/src/graph_simplifier.cpp
@@ -165,10 +165,7 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
     Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds.back());
-    for (int i = matchedNodesIds.size() - 2; i >= 0; --i)
-        net->removeNode(matchedNodesIds[i]);
-
     // Modify the last node to be a fused one.
     node->setType(fusedNodeOp);
@@ -191,6 +188,7 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
 {
     int numNodes = net->getNumNodes();
     std::vector<int> matchedNodesIds, targetNodesIds;
+    std::vector<int> nodesToRemove;
     for (int j = 0; j < patterns.size(); ++j)
     {
         for (int i = 0; i < numNodes; ++i)
@@ -198,10 +196,54 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
         {
             if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
             {
                 patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
-                numNodes -= matchedNodesIds.size() - 1;  // #matchedNodes removed and one added.
+                // Remove matched nodes except the last one.
+                nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
             }
         }
     }
+
+    if (nodesToRemove.empty())
+        return;
+
+    // Collect reference counts for every node
+    std::vector<int> refcounts(net->getNumNodes(), 0);
+    std::map<std::string, int> nodeIds;
+
+    // Register node outputs.
+    // Every usage of one of the node's outputs should be counted.
+    for (int nodeId = 0; nodeId < refcounts.size(); ++nodeId) {
+        for (int i = 0; i < net->getNumOutputs(nodeId); ++i) {
+            std::string name = net->getOutputName(nodeId, i);
+            nodeIds[name] = nodeId;
+        }
+    }
+
+    for (int nodeId = 0; nodeId < refcounts.size(); ++nodeId) {
+        // Increase counters for node's inputs
+        auto node = net->getNode(nodeId);
+        for (int i = 0; i < node->getNumInputs(); ++i) {
+            std::string inpName = node->getInputName(i);
+            if (inpName.empty())
+                continue;
+            CV_Assert(nodeIds.find(inpName) != nodeIds.end());
+            refcounts[nodeIds[inpName]] += 1;
+        }
+    }
+
+    // Remove all fused nodes. Indices expected to be in descending order.
+    std::sort(nodesToRemove.begin(), nodesToRemove.end(), [](int a, int b) { return a > b; });
+    for (int nodeId : nodesToRemove) {
+        if (refcounts[nodeId] == 0) {
+            // Decrease references to node's inputs and remove node itself
+            auto node = net->getNode(nodeId);
+            for (int i = 0; i < node->getNumInputs(); ++i) {
+                std::string inpName = node->getInputName(i);
+                refcounts[nodeIds[inpName]] -= 1;
+            }
+            net->removeNode(nodeId);
+            refcounts[nodeId] = -1;  // Same node cannot be removed twice
+        }
+    }
 }
 
 }}  // namespace cv::dnn
diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
index a43815dbe4..15f79c8769 100644
--- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
+++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp
@@ -1136,6 +1136,33 @@ public:
     }
 };
 
+class ResizeSubgraph3 : public Subgraph
+{
+public:
+    ResizeSubgraph3() : Subgraph()
+    {
+        int shapeSrc = addNodeToMatch("");
+        int input = addNodeToMatch("");
+
+        int shape_h = addNodeToMatch("Shape", shapeSrc);
+        int shape_w = addNodeToMatch("Shape", shapeSrc);
+        int gather_h = addNodeToMatch("Gather", shape_h, addNodeToMatch("Constant"));
+        int gather_w = addNodeToMatch("Gather", shape_w, addNodeToMatch("Constant"));
+        int unsqueeze_h = addNodeToMatch("Unsqueeze", gather_h);
+        int unsqueeze_w = addNodeToMatch("Unsqueeze", gather_w);
+        int concat1 = addNodeToMatch("Concat", unsqueeze_h, unsqueeze_w);
+        int cast = addNodeToMatch("Cast", concat1);
+
+        int shape2 = addNodeToMatch("Shape", input);
+        int slice = addNodeToMatch("Slice", shape2, addNodeToMatch("Constant"), addNodeToMatch("Constant"), addNodeToMatch("Constant"));
+        int concat2 = addNodeToMatch("Concat", slice, cast);
+        addNodeToMatch("Resize", input, addNodeToMatch("Constant"), addNodeToMatch("Constant"), concat2);
+
+        setFusedNode("Upsample", input, shapeSrc);
+    }
+};
+
+
 class BatchNormalizationSubgraphBase : public Subgraph
 {
 public:
@@ -1207,6 +1234,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
     subgraphs.push_back(makePtr());
     subgraphs.push_back(makePtr());
     subgraphs.push_back(makePtr());
+    subgraphs.push_back(makePtr<ResizeSubgraph3>());
     subgraphs.push_back(makePtr());
     subgraphs.push_back(makePtr());
     subgraphs.push_back(makePtr());
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index b7e4e73cbc..cea4ffd739 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -54,7 +54,8 @@ public:
     void testONNXModels(const String& basename, const Extension ext = npy,
                         double l1 = 0, double lInf = 0, const bool useSoftmax = false,
-                        bool checkNoFallbacks = true, int numInps = 1)
+                        bool checkNoFallbacks = true, int numInps = 1,
+                        bool testShapes = true)
     {
         String onnxmodel = _tf("models/" + basename + ".onnx", required);
         std::vector<Mat> inps(numInps);
@@ -76,7 +77,8 @@ public:
 
         Net net = readNetFromONNX(onnxmodel);
         ASSERT_FALSE(net.empty());
-        testInputShapes(net, inps);
+        if (testShapes)
+            testInputShapes(net, inps);
 
         net.setPreferableBackend(backend);
         net.setPreferableTarget(target);
@@ -248,6 +250,10 @@ TEST_P(Test_ONNX_layers, Gather_shared_indices) {
     testONNXModels("gather_shared_indices", npy, 0, 0, false, false, 1);
 }
 
+TEST_P(Test_ONNX_layers, Two_resizes_with_shared_subgraphs) {
+    testONNXModels("two_resizes_with_shared_subgraphs", npy, 0, 0, false, false, 3, /*testShapes*/ false);
+}
+
 TEST_P(Test_ONNX_layers, Convolution3D)
 {
     if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16)
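
For context on the graph_simplifier.cpp change: instead of deleting every matched node as soon as a pattern fuses, the patch records removal candidates and later drops only those whose outputs are no longer consumed, so a node shared by two fusions (as in the two-Resize test model) is not deleted while the other fusion still needs it. The following standalone sketch reproduces that reference-counting pass over a toy graph; TinyNode and removeUnreferenced are illustrative names only, not OpenCV API, and the sketch is independent of the ImportGraphWrapper interface used in the patch.

#include <algorithm>
#include <cassert>
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct TinyNode
{
    std::vector<std::string> inputs;   // names of consumed outputs
    std::vector<std::string> outputs;  // names this node produces
    bool removed = false;
};

// Remove the nodes listed in nodesToRemove, but only those whose outputs are
// no longer referenced by any surviving node (shared nodes stay alive).
static void removeUnreferenced(std::vector<TinyNode>& nodes, std::vector<int> nodesToRemove)
{
    // Map every output name to the node that produces it.
    std::map<std::string, int> producer;
    for (int id = 0; id < (int)nodes.size(); ++id)
        for (const std::string& out : nodes[id].outputs)
            producer[out] = id;

    // Count how many times each node's outputs are consumed.
    std::vector<int> refcounts(nodes.size(), 0);
    for (const TinyNode& node : nodes)
        for (const std::string& inp : node.inputs)
        {
            assert(producer.count(inp));
            refcounts[producer[inp]] += 1;
        }

    // Visit candidates from the end of the graph towards the beginning, so
    // dropping a consumer can also free the (shared) producers it used.
    std::sort(nodesToRemove.begin(), nodesToRemove.end(), [](int a, int b) { return a > b; });
    for (int id : nodesToRemove)
    {
        if (refcounts[id] != 0)
            continue;  // still referenced by a surviving node: keep it
        for (const std::string& inp : nodes[id].inputs)
            refcounts[producer[inp]] -= 1;
        nodes[id].removed = true;
        refcounts[id] = -1;  // guard against removing the same node twice
    }
}

int main()
{
    // Toy graph: node 0 produces "shape", consumed by nodes 1 and 2.
    std::vector<TinyNode> nodes(3);
    nodes[0].outputs = {"shape"};
    nodes[1].inputs  = {"shape"};
    nodes[1].outputs = {"resize1"};
    nodes[2].inputs  = {"shape"};
    nodes[2].outputs = {"resize2"};

    // Ask to remove node 1 and the shared node 0: only node 1 goes away,
    // node 0 survives because node 2 still references "shape".
    removeUnreferenced(nodes, {1, 0});
    for (size_t i = 0; i < nodes.size(); ++i)
        std::cout << "node " << i << (nodes[i].removed ? " removed" : " kept") << "\n";
}

Run on this toy graph, the sketch keeps node 0 alive even though it was listed for removal, which mirrors why the new test with two Resize ops sharing a Shape subgraph no longer breaks the importer.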