mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 20:09:23 +08:00
Merge pull request #18296 from sl-sergei:fix_16783
Fix loading issue for Faster RCNN model from #16783 * Add a reproducer with multi-output Gather * Fix an issue with ONNX graph simplifier * fix build * Move checks to correct class * Minor changes for better code appearence
This commit is contained in:
parent
564d1a0f79
commit
2b82f8f12c
@ -260,6 +260,40 @@ public:
|
||||
addNodeToMatch("Cast", gather);
|
||||
setFusedNode("Gather", input, index);
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
{
|
||||
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
|
||||
size_t matchedNodesNum = matchedNodesIds.size();
|
||||
// Now we check if merging can be made for these Gather and Cast nodes
|
||||
if (!retVal || matchedNodesNum < 2)
|
||||
return retVal;
|
||||
else {
|
||||
int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
||||
if (node->getType() == "Cast") {
|
||||
int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
|
||||
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
|
||||
if (inpNode->getType() == "Gather") {
|
||||
int numNodes = net->getNumNodes();
|
||||
std::string inpNodeName = node->getInputName(0);
|
||||
for (int i = 0; i < numNodes; ++i) {
|
||||
const Ptr<ImportNodeWrapper> node_to_check = net->getNode(i);
|
||||
int numInp = node_to_check->getNumInputs();
|
||||
for (int inp = 0; inp < numInp; ++inp) {
|
||||
if (i != nodeToMatch && inpNodeName == node_to_check->getInputName(0)) {
|
||||
// Another node has the same input node, so it cannot be merged.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return retVal;
|
||||
}
|
||||
};
|
||||
|
||||
class ExpandSubgraph : public Subgraph
|
||||
|
@ -705,6 +705,11 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias)
|
||||
normAssert(ref, out, "", default_l1, default_lInf);
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, GatherMultiOutput)
|
||||
{
|
||||
testONNXModels("gather_multi_output");
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
|
||||
|
||||
class Test_ONNX_nets : public Test_ONNX_layers
|
||||
|
Loading…
Reference in New Issue
Block a user