diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index e8b237cab4..30c0b26ead 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -260,6 +260,40 @@ public: addNodeToMatch("Cast", gather); setFusedNode("Gather", input, index); } + + virtual bool match(const Ptr& net, int nodeId, + std::vector& matchedNodesIds, + std::vector& targetNodesIds) CV_OVERRIDE + { + bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds); + size_t matchedNodesNum = matchedNodesIds.size(); + // Now we check if merging can be made for these Gather and Cast nodes + if (!retVal || matchedNodesNum < 2) + return retVal; + else { + int nodeToMatch = matchedNodesIds[matchedNodesNum - 1]; + const Ptr node = net->getNode(nodeToMatch); + if (node->getType() == "Cast") { + int inpNodeId = matchedNodesIds[matchedNodesNum - 2]; + const Ptr inpNode = net->getNode(inpNodeId); + if (inpNode->getType() == "Gather") { + int numNodes = net->getNumNodes(); + std::string inpNodeName = node->getInputName(0); + for (int i = 0; i < numNodes; ++i) { + const Ptr node_to_check = net->getNode(i); + int numInp = node_to_check->getNumInputs(); + for (int inp = 0; inp < numInp; ++inp) { + if (i != nodeToMatch && inpNodeName == node_to_check->getInputName(0)) { + // Another node has the same input node, so it cannot be merged. + return false; + } + } + } + } + } + } + return retVal; + } }; class ExpandSubgraph : public Subgraph diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 5c6de55da5..14d2d28522 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -705,6 +705,11 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias) normAssert(ref, out, "", default_l1, default_lInf); } +TEST_P(Test_ONNX_layers, GatherMultiOutput) +{ + testONNXModels("gather_multi_output"); +} + INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets()); class Test_ONNX_nets : public Test_ONNX_layers