mirror of
https://github.com/opencv/opencv.git
synced 2025-06-17 15:20:51 +08:00
Merge pull request #18296 from sl-sergei:fix_16783
Fix loading issue for Faster RCNN model from #16783 * Add a reproducer with multi-output Gather * Fix an issue with ONNX graph simplifier * fix build * Move checks to correct class * Minor changes for better code appearence
This commit is contained in:
parent
564d1a0f79
commit
2b82f8f12c
@ -260,6 +260,40 @@ public:
|
|||||||
addNodeToMatch("Cast", gather);
|
addNodeToMatch("Cast", gather);
|
||||||
setFusedNode("Gather", input, index);
|
setFusedNode("Gather", input, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
|
std::vector<int>& matchedNodesIds,
|
||||||
|
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||||
|
{
|
||||||
|
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
|
||||||
|
size_t matchedNodesNum = matchedNodesIds.size();
|
||||||
|
// Now we check if merging can be made for these Gather and Cast nodes
|
||||||
|
if (!retVal || matchedNodesNum < 2)
|
||||||
|
return retVal;
|
||||||
|
else {
|
||||||
|
int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
|
||||||
|
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
||||||
|
if (node->getType() == "Cast") {
|
||||||
|
int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
|
||||||
|
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
|
||||||
|
if (inpNode->getType() == "Gather") {
|
||||||
|
int numNodes = net->getNumNodes();
|
||||||
|
std::string inpNodeName = node->getInputName(0);
|
||||||
|
for (int i = 0; i < numNodes; ++i) {
|
||||||
|
const Ptr<ImportNodeWrapper> node_to_check = net->getNode(i);
|
||||||
|
int numInp = node_to_check->getNumInputs();
|
||||||
|
for (int inp = 0; inp < numInp; ++inp) {
|
||||||
|
if (i != nodeToMatch && inpNodeName == node_to_check->getInputName(0)) {
|
||||||
|
// Another node has the same input node, so it cannot be merged.
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return retVal;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class ExpandSubgraph : public Subgraph
|
class ExpandSubgraph : public Subgraph
|
||||||
|
@ -705,6 +705,11 @@ TEST_P(Test_ONNX_layers, Conv1d_variable_weight_bias)
|
|||||||
normAssert(ref, out, "", default_l1, default_lInf);
|
normAssert(ref, out, "", default_l1, default_lInf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_ONNX_layers, GatherMultiOutput)
|
||||||
|
{
|
||||||
|
testONNXModels("gather_multi_output");
|
||||||
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
|
INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());
|
||||||
|
|
||||||
class Test_ONNX_nets : public Test_ONNX_layers
|
class Test_ONNX_nets : public Test_ONNX_layers
|
||||||
|
Loading…
Reference in New Issue
Block a user