mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 14:36:36 +08:00
Merge pull request #24577 from dkurt:dnn_graph_match_stack
Fix graph fusion with commutative ops #24577

### Pull Request Readiness Checklist

Resolves https://github.com/opencv/opencv/issues/24568

**Merge with extra**: https://github.com/opencv/opencv_extra/pull/1125

TODO:
- [x] Replace the recursive function with a sequential (iterative) one

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is an accuracy test, performance test and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
848dd12a1f
commit
332748dd55
@ -81,26 +81,45 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
{
|
||||
matchedNodesIds.clear();
|
||||
|
||||
std::queue<int> nodesToMatch;
|
||||
std::queue<int> targetNodes;
|
||||
std::vector<std::pair<int, int> > matchings;
|
||||
matchings.reserve(nodes.size());
|
||||
nodesToMatch.push(nodeId);
|
||||
targetNodes.push(nodes.size() - 1);
|
||||
while (!nodesToMatch.empty())
|
||||
{
|
||||
int nodeToMatch = nodesToMatch.front();
|
||||
int targetNodeId = targetNodes.front();
|
||||
nodesToMatch.pop();
|
||||
targetNodes.pop();
|
||||
// Collection of all matchings states across branching.
|
||||
// If there are no commutative ops in the subgraph, there will be just a single map.
|
||||
std::vector<std::shared_ptr<std::map<int, int>>> matchCandidates;
|
||||
matchCandidates.push_back(makePtr<std::map<int, int>>());
|
||||
|
||||
if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair<int, int>& match){ return match.first == targetNodeId; }) !=
|
||||
matchings.end())
|
||||
// One pending step of the matching loop: a node id from the imported graph
// paired with the index of the pattern node it has to be matched against,
// together with the matching maps that a successful match is recorded into.
struct State
|
||||
{
|
||||
// Id of the node in the imported graph (`net`) that is being examined.
int nodeToMatch;
|
||||
// Index into the subgraph's `nodes` list that nodeToMatch is compared with.
int targetNodeId;
|
||||
// Every state refers to the current matching pairs as well as
|
||||
// the matchings of parent branches produced by commutative ops.
|
||||
// The last element is the map of the branch this state belongs to.
std::vector<std::shared_ptr<std::map<int, int>>> matchings;
|
||||
|
||||
// When a matching pair is registered, it must be registered in every parent branch too.
|
||||
// This only matters for branching caused by commutative ops.
|
||||
// `match` is {target node index, graph node id}. std::map::insert keeps an
// already present entry for the same key, so an earlier registration is not overwritten.
void addMatch(std::pair<int, int> match)
|
||||
{
|
||||
for (auto& m : matchings)
|
||||
m->insert(match);
|
||||
}
|
||||
};
|
||||
|
||||
std::queue<State> states;
|
||||
states.push({nodeId, (int)nodes.size() - 1, matchCandidates});
|
||||
|
||||
while (!states.empty())
|
||||
{
|
||||
auto state = states.front();
|
||||
states.pop();
|
||||
int nodeToMatch = state.nodeToMatch;
|
||||
int targetNodeId = state.targetNodeId;
|
||||
auto matchings = state.matchings.back();
|
||||
|
||||
if (matchings->find(targetNodeId) != matchings->end())
|
||||
continue;
|
||||
|
||||
// Empty placeholder matches with any input type
|
||||
if (nodes[targetNodeId].empty()) {
|
||||
matchings.push_back({targetNodeId, nodeToMatch});
|
||||
state.addMatch({targetNodeId, nodeToMatch});
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -112,42 +131,50 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
if (inputNodes.size() != node->getNumInputs())
|
||||
continue;
|
||||
|
||||
bool isCommutative = net->isCommutativeOp(node->getType());
|
||||
state.addMatch({targetNodeId, nodeToMatch});
|
||||
|
||||
for (int j = 0; j < inputNodes.size(); ++j)
|
||||
bool isCommutative = net->isCommutativeOp(node->getType());
|
||||
if (isCommutative)
|
||||
{
|
||||
// Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
|
||||
if (node->getInputName(j).empty())
|
||||
continue;
|
||||
nodeId = getInputNodeId(net, node, j);
|
||||
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
|
||||
if (isCommutative)
|
||||
if (inputNodes.size() != 2)
|
||||
CV_Error(Error::StsNotImplemented, "Commutative op fusion with more than 2 inputs");
|
||||
|
||||
auto newMatchings = makePtr<std::map<int, int>>(*matchings);
|
||||
matchCandidates.push_back(newMatchings);
|
||||
state.matchings.push_back(newMatchings);
|
||||
states.push({getInputNodeId(net, node, 0), inputNodes[0], state.matchings});
|
||||
states.push({getInputNodeId(net, node, 1), inputNodes[1], state.matchings});
|
||||
state.matchings.pop_back();
|
||||
|
||||
newMatchings = makePtr<std::map<int, int>>(*matchings);
|
||||
matchCandidates.push_back(newMatchings);
|
||||
state.matchings.push_back(newMatchings);
|
||||
states.push({getInputNodeId(net, node, 0), inputNodes[1], state.matchings});
|
||||
states.push({getInputNodeId(net, node, 1), inputNodes[0], state.matchings});
|
||||
state.matchings.pop_back();
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < inputNodes.size(); ++j)
|
||||
{
|
||||
for (int i = 0; i < inputNodes.size(); ++i)
|
||||
{
|
||||
nodesToMatch.push(nodeId);
|
||||
targetNodes.push(inputNodes[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
nodesToMatch.push(nodeId);
|
||||
targetNodes.push(inputNodes[j]);
|
||||
nodeId = getInputNodeId(net, node, j);
|
||||
states.push({nodeId, inputNodes[j], state.matchings});
|
||||
}
|
||||
}
|
||||
matchings.push_back({targetNodeId, nodeToMatch});
|
||||
}
|
||||
if (matchings.size() != nodes.size())
|
||||
return false;
|
||||
|
||||
// Sort matched by pattern nodes order.
|
||||
std::sort(matchings.begin(), matchings.end());
|
||||
matchedNodesIds.resize(matchings.size());
|
||||
for (int i = 0; i < matchings.size(); ++i)
|
||||
for (auto& matchings : matchCandidates)
|
||||
{
|
||||
matchedNodesIds[i] = matchings[i].second;
|
||||
if (matchings->size() != nodes.size())
|
||||
continue;
|
||||
matchedNodesIds.resize(matchings->size());
|
||||
for (int i = 0; i < matchings->size(); ++i)
|
||||
{
|
||||
CV_Assert(matchings->find(i) != matchings->end());
|
||||
matchedNodesIds[i] = matchings->at(i);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds)
|
||||
|
@ -64,6 +64,12 @@ class ONNXGraphWrapper : public ImportGraphWrapper
|
||||
public:
|
||||
// Wraps an ONNX GraphProto: appends one placeholder initializer to the graph
// and then caches the number of graph inputs and initializers.
// NOTE(review): `_net` is taken by non-const reference and add_initializer() is
// called on the `net` member; if `net` is a reference member, the caller's
// graph is modified — confirm against the member declaration.
ONNXGraphWrapper(opencv_onnx::GraphProto& _net) : net(_net)
|
||||
{
|
||||
// Add a fake initializer with an empty name.
|
||||
// Some ONNX models leave some of their inputs unspecified. For example,
|
||||
// Resize has 4 inputs, but 2 of them may have empty names.
|
||||
// So we add a fake empty node that such ops can refer to as their input.
|
||||
net.add_initializer();
|
||||
|
||||
numInputs = net.input_size();
|
||||
// Read after add_initializer(), so this count includes the fake initializer.
numInitializers = net.initializer_size();
|
||||
}
|
||||
|
@ -3539,7 +3539,7 @@ void ONNXImporter::parseQGemm(LayerParams& layerParams, const opencv_onnx::NodeP
|
||||
Mat bias;
|
||||
if (constBlobs.find(node_proto.input(6)) != constBlobs.end())
|
||||
bias = getBlob(node_proto, 6);
|
||||
else
|
||||
if (bias.empty())
|
||||
bias = Mat::zeros(1, outCn, CV_32S);
|
||||
|
||||
Mat biasFused(1, outCn, CV_32S);
|
||||
|
@ -35,6 +35,7 @@ class Test_Graph_Simplifier : public ::testing::Test {
|
||||
|
||||
// Graph simplifier test for the Gelu subgraph patterns.
// NOTE(review): the fixture's test() helper is not visible here; presumably the
// first argument names the test model and the second gives the layer type(s)
// expected after fusion — confirm against Test_Graph_Simplifier::test.
TEST_F(Test_Graph_Simplifier, GeluSubGraph) {
|
||||
// Plain Gelu pattern: a single expected type, "Gelu".
test("gelu", "Gelu");
|
||||
// Gelu combined with a bias: two expected types, "Gelu" and "NaryEltwise".
test("bias_gelu", std::vector<std::string>{"Gelu", "NaryEltwise"});
|
||||
}
|
||||
|
||||
TEST_F(Test_Graph_Simplifier, GeluApproximationSubGraph) {
|
||||
|
Loading…
Reference in New Issue
Block a user