mirror of
https://github.com/opencv/opencv.git
synced 2025-08-05 22:19:14 +08:00
Merge pull request #24483 from dkurt:dnn_fusion_commutative_ops
Commutative rules for DNN subgraphs fusion #24483

### Pull Request Readiness Checklist

related: https://github.com/opencv/opencv/pull/24463#issuecomment-1783033931

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
41c335e5a5
commit
b7ec2ebb55
@ -77,14 +77,14 @@ int Subgraph::getInputNodeId(const Ptr<ImportGraphWrapper>& net,
|
||||
}
|
||||
|
||||
bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds)
|
||||
std::vector<int>& matchedNodesIds)
|
||||
{
|
||||
matchedNodesIds.clear();
|
||||
targetNodesIds.clear();
|
||||
|
||||
std::queue<int> nodesToMatch;
|
||||
std::queue<int> targetNodes;
|
||||
std::vector<std::pair<int, int> > matchings;
|
||||
matchings.reserve(nodes.size());
|
||||
nodesToMatch.push(nodeId);
|
||||
targetNodes.push(nodes.size() - 1);
|
||||
while (!nodesToMatch.empty())
|
||||
@ -94,51 +94,63 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
nodesToMatch.pop();
|
||||
targetNodes.pop();
|
||||
|
||||
if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) !=
|
||||
matchedNodesIds.end())
|
||||
if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair<int, int>& match){ return match.first == targetNodeId; }) !=
|
||||
matchings.end())
|
||||
continue;
|
||||
|
||||
// Empty placeholder matches with any input type
|
||||
if (nodes[targetNodeId].empty()) {
|
||||
matchings.push_back({targetNodeId, nodeToMatch});
|
||||
continue;
|
||||
}
|
||||
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
||||
if (node->getType() != nodes[targetNodeId])
|
||||
return false;
|
||||
continue;
|
||||
|
||||
std::vector<int>& inputNodes = inputs[targetNodeId];
|
||||
if (inputNodes.size() != node->getNumInputs())
|
||||
return false;
|
||||
continue;
|
||||
|
||||
bool isCommutative = net->isCommutativeOp(node->getType());
|
||||
|
||||
for (int j = 0; j < inputNodes.size(); ++j)
|
||||
{
|
||||
if (nodes[inputNodes[j]].empty() || node->getInputName(j).empty()) // Unknown input node type.
|
||||
// Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
|
||||
if (node->getInputName(j).empty())
|
||||
continue;
|
||||
nodeId = getInputNodeId(net, node, j);
|
||||
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
|
||||
if (inpNode->getType() != "Const" && inpNode->getType() != "Constant")
|
||||
if (isCommutative)
|
||||
{
|
||||
for (int i = 0; i < inputNodes.size(); ++i)
|
||||
{
|
||||
nodesToMatch.push(nodeId);
|
||||
targetNodes.push(inputNodes[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
nodesToMatch.push(nodeId);
|
||||
targetNodes.push(inputNodes[j]);
|
||||
}
|
||||
else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant")
|
||||
return false;
|
||||
}
|
||||
matchedNodesIds.push_back(nodeToMatch);
|
||||
targetNodesIds.push_back(targetNodeId);
|
||||
matchings.push_back({targetNodeId, nodeToMatch});
|
||||
}
|
||||
if (matchings.size() != nodes.size())
|
||||
return false;
|
||||
|
||||
const int n = matchedNodesIds.size();
|
||||
std::vector<std::pair<int, int> > elements(n);
|
||||
for (int i = 0; i < n; ++i)
|
||||
elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]);
|
||||
std::sort(elements.begin(), elements.end());
|
||||
for (int i = 0; i < n; ++i)
|
||||
// Sort matched by pattern nodes order.
|
||||
std::sort(matchings.begin(), matchings.end());
|
||||
matchedNodesIds.resize(matchings.size());
|
||||
for (int i = 0; i < matchings.size(); ++i)
|
||||
{
|
||||
matchedNodesIds[i] = elements[i].first;
|
||||
targetNodesIds[i] = elements[i].second;
|
||||
matchedNodesIds[i] = matchings[i].second;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
|
||||
const std::vector<int>& targetNodesIds)
|
||||
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds)
|
||||
{
|
||||
// Extract names of input nodes.
|
||||
std::vector<std::string> inputsNames(fusedNodeInputs.size());
|
||||
@ -149,9 +161,9 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int
|
||||
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
|
||||
{
|
||||
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]);
|
||||
std::vector<int>& inpIndices = inputs[targetNodesIds[j]];
|
||||
std::vector<int>& inpIndices = inputs[j];
|
||||
|
||||
CV_Assert(node->getNumInputs() == inpIndices.size());
|
||||
CV_Assert(inpIndices.empty() || node->getNumInputs() == inpIndices.size());
|
||||
for (int k = 0; k < inpIndices.size(); ++k)
|
||||
{
|
||||
if (inpIndices[k] == fusedNodeInputs[i])
|
||||
@ -187,15 +199,15 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
|
||||
const std::vector<Ptr<Subgraph> >& patterns)
|
||||
{
|
||||
int numNodes = net->getNumNodes();
|
||||
std::vector<int> matchedNodesIds, targetNodesIds;
|
||||
std::vector<int> matchedNodesIds;
|
||||
std::vector<int> nodesToRemove;
|
||||
for (int j = 0; j < patterns.size(); ++j)
|
||||
{
|
||||
for (int i = 0; i < numNodes; ++i)
|
||||
{
|
||||
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
|
||||
if (patterns[j]->match(net, i, matchedNodesIds))
|
||||
{
|
||||
patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
|
||||
patterns[j]->replace(net, matchedNodesIds);
|
||||
// Remove matched nodes except the last one.
|
||||
nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
|
||||
}
|
||||
|
@ -44,6 +44,8 @@ public:
|
||||
virtual std::string getOutputName(int nodeId, int outId) const = 0;
|
||||
|
||||
virtual void removeNode(int idx) = 0;
|
||||
|
||||
virtual bool isCommutativeOp(const std::string& type) const = 0;
|
||||
};
|
||||
|
||||
class Subgraph // Interface to match and replace subgraphs.
|
||||
@ -75,12 +77,10 @@ public:
|
||||
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
|
||||
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds);
|
||||
std::vector<int>& matchedNodesIds);
|
||||
|
||||
// Fuse matched subgraph.
|
||||
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
|
||||
const std::vector<int>& targetNodesIds);
|
||||
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds);
|
||||
|
||||
virtual void finalize(const Ptr<ImportGraphWrapper>& net,
|
||||
const Ptr<ImportNodeWrapper>& fusedNode,
|
||||
|
@ -125,8 +125,13 @@ public:
|
||||
|
||||
virtual void removeNode(int idx) CV_OVERRIDE
|
||||
{
|
||||
CV_Assert(idx >= numInputs + numInitializers);
|
||||
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
|
||||
if (idx >= numInputs + numInitializers)
|
||||
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
|
||||
}
|
||||
|
||||
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
|
||||
{
|
||||
return type == "Add" || type == "Mul" || type == "Equal" || type == "Max";
|
||||
}
|
||||
|
||||
private:
|
||||
@ -134,6 +139,25 @@ private:
|
||||
opencv_onnx::GraphProto& net;
|
||||
};
|
||||
|
||||
static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1)
|
||||
{
|
||||
return onnx_net->getMatFromInitializer(initializer_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = Subgraph::getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
||||
return getMatFromTensor(constant_proto);
|
||||
}
|
||||
}
|
||||
|
||||
/* Fusion for Gelu.
|
||||
|
||||
Graph before fusion:
|
||||
@ -151,54 +175,32 @@ public:
|
||||
GeluSubGraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
|
||||
div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
|
||||
int erf = addNodeToMatch("Erf", div);
|
||||
int add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
|
||||
add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
|
||||
int mul = addNodeToMatch("Mul", input, add);
|
||||
addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
|
||||
mul2 = addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
|
||||
|
||||
setFusedNode("Gelu", input);
|
||||
}
|
||||
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1)
|
||||
{
|
||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
||||
return *const_mat.ptr<float>();
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
||||
Mat constant_mat = getMatFromTensor(constant_proto);
|
||||
return *constant_mat.ptr<float>();
|
||||
}
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
{
|
||||
// Check Div[B=sqrt(2)]
|
||||
float divisor = extractConstant(net, matchedNodesIds[0], 1);
|
||||
float divisor = extractConstant(net, matchedNodesIds[div], 1).at<float>(0);
|
||||
if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon())
|
||||
return false;
|
||||
|
||||
// Check Add[B=1]
|
||||
float add_const = extractConstant(net, matchedNodesIds[2], 1);
|
||||
float add_const = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
|
||||
if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon())
|
||||
return false;
|
||||
|
||||
// Check Mul[B=0.5]
|
||||
float mul_const = extractConstant(net, matchedNodesIds[4], 1);
|
||||
float mul_const = extractConstant(net, matchedNodesIds[mul2], 1).at<float>(0);
|
||||
if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon())
|
||||
return false;
|
||||
|
||||
@ -206,6 +208,9 @@ public:
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int div, add, mul2;
|
||||
};
|
||||
|
||||
/* Fusion for GeluApproximation.
|
||||
@ -229,61 +234,39 @@ public:
|
||||
int input = addNodeToMatch("");
|
||||
int mul0 = addNodeToMatch("Mul", input, input);
|
||||
int mul1 = addNodeToMatch("Mul", input, mul0);
|
||||
int mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
|
||||
mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
|
||||
int add0 = addNodeToMatch("Add", input, mul2);
|
||||
int mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
|
||||
mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
|
||||
int tanh = addNodeToMatch("Tanh", mul3);
|
||||
int add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
|
||||
add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
|
||||
int mul4 = addNodeToMatch("Mul", input, add1);
|
||||
addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
|
||||
mul5 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
|
||||
|
||||
setFusedNode("GeluApproximation", input);
|
||||
}
|
||||
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1)
|
||||
{
|
||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
||||
return *const_mat.ptr<float>();
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
||||
Mat constant_mat = getMatFromTensor(constant_proto);
|
||||
return *constant_mat.ptr<float>();
|
||||
}
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
{
|
||||
// Check Mul[A=0.044714998453855515]
|
||||
float coef = extractConstant(net, matchedNodesIds[2], 0);
|
||||
float coef = extractConstant(net, matchedNodesIds[mul2], 0).at<float>(0);
|
||||
if (coef - 0.044714998453855515 >= 1e-6)
|
||||
return false;
|
||||
|
||||
// Check Mul[A=sqrt(2/pie)]
|
||||
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0);
|
||||
float sqrt_2_pie = extractConstant(net, matchedNodesIds[mul3], 0).at<float>(0);
|
||||
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
|
||||
return false;
|
||||
|
||||
// Check Add[A=1]
|
||||
float add_const = extractConstant(net, matchedNodesIds[6], 0);
|
||||
float add_const = extractConstant(net, matchedNodesIds[add1], 0).at<float>(0);
|
||||
if (add_const - 1.f >= 1e-6)
|
||||
return false;
|
||||
|
||||
// Check Mul[A=0.5]
|
||||
float mul_const = extractConstant(net, matchedNodesIds[8], 0);
|
||||
float mul_const = extractConstant(net, matchedNodesIds[mul5], 0).at<float>(0);
|
||||
if (mul_const - 0.5f >= 1e-6)
|
||||
return false;
|
||||
|
||||
@ -291,6 +274,9 @@ public:
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int mul2, mul3, add1, mul5;
|
||||
};
|
||||
|
||||
/* Fusion for LayerNormalization.
|
||||
@ -313,43 +299,22 @@ public:
|
||||
LayerNormSubGraph() : axis(-1), epsilon(1e-5)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int mean = addNodeToMatch("ReduceMean", input);
|
||||
mean = addNodeToMatch("ReduceMean", input);
|
||||
|
||||
int sub = addNodeToMatch("Sub", input, mean);
|
||||
|
||||
int pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
|
||||
int mean1 = addNodeToMatch("ReduceMean", pow);
|
||||
int add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
|
||||
pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
|
||||
mean1 = addNodeToMatch("ReduceMean", pow);
|
||||
add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
|
||||
int sqrt = addNodeToMatch("Sqrt", add);
|
||||
|
||||
int div = addNodeToMatch("Div", sub, sqrt);
|
||||
int mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
|
||||
addNodeToMatch("Add", mul, addNodeToMatch(""));
|
||||
mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
|
||||
bias = addNodeToMatch("Add", mul, addNodeToMatch(""));
|
||||
|
||||
setFusedNode("LayerNormalization", input);
|
||||
}
|
||||
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1) // initializer
|
||||
{
|
||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
||||
return *const_mat.ptr<float>();
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
||||
Mat constant_mat = getMatFromTensor(constant_proto);
|
||||
return *constant_mat.ptr<float>();
|
||||
}
|
||||
}
|
||||
|
||||
static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
|
||||
{
|
||||
Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
|
||||
@ -381,25 +346,24 @@ public:
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
{
|
||||
float pow_exp = extractConstant(net, matchedNodesIds[2], 1);
|
||||
float pow_exp = extractConstant(net, matchedNodesIds[pow], 1).at<float>(0);
|
||||
if (pow_exp - 2 > 1e-5) // not pow(2)
|
||||
return false;
|
||||
|
||||
int axis_mean1 = extractAxis(net, matchedNodesIds[0]);
|
||||
int axis_mean2 = extractAxis(net, matchedNodesIds[3]);
|
||||
int axis_mean1 = extractAxis(net, matchedNodesIds[mean]);
|
||||
int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]);
|
||||
if (axis_mean1 != axis_mean2)
|
||||
return false;
|
||||
axis = axis_mean1;
|
||||
|
||||
epsilon = extractConstant(net, matchedNodesIds[4], 1);
|
||||
epsilon = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
|
||||
|
||||
weight_name = getInputName(net, matchedNodesIds[7], 1);
|
||||
bias_name = getInputName(net, matchedNodesIds[8], 1);
|
||||
weight_name = getInputName(net, matchedNodesIds[mul], 1);
|
||||
bias_name = getInputName(net, matchedNodesIds[bias], 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -429,6 +393,7 @@ protected:
|
||||
float epsilon;
|
||||
std::string weight_name;
|
||||
std::string bias_name;
|
||||
int pow, mean, mean1, add, mul, bias;
|
||||
};
|
||||
|
||||
class SoftMaxSubgraphBase : public Subgraph
|
||||
@ -437,10 +402,9 @@ public:
|
||||
SoftMaxSubgraphBase() : axis(1), id(-1) {}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
{
|
||||
CV_Assert(id >= 0 && id < matchedNodesIds.size());
|
||||
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
|
||||
@ -485,7 +449,7 @@ public:
|
||||
int inpExp = addNodeToMatch("Exp", input);
|
||||
|
||||
int sum = addNodeToMatch("ReduceSum", inpExp);
|
||||
id = 1;
|
||||
id = sum;
|
||||
|
||||
addNodeToMatch("Div", inpExp, sum);
|
||||
setFusedNode("Softmax", input);
|
||||
@ -498,7 +462,7 @@ public:
|
||||
int input = addNodeToMatch("");
|
||||
|
||||
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||
id = 0;
|
||||
id = reducemax;
|
||||
|
||||
int sub = addNodeToMatch("Sub", input, reducemax);
|
||||
int exp = addNodeToMatch("Exp", sub);
|
||||
@ -516,7 +480,7 @@ public:
|
||||
int input = addNodeToMatch("");
|
||||
|
||||
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||
id = 0;
|
||||
id = reducemax;
|
||||
|
||||
int sub_1 = addNodeToMatch("Sub", input, reducemax);
|
||||
int exp = addNodeToMatch("Exp", sub_1);
|
||||
@ -533,18 +497,17 @@ public:
|
||||
HardSwishSubgraph()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int hardSigmoid = addNodeToMatch("HardSigmoid", input);
|
||||
addNodeToMatch("Mul", input, hardSigmoid);
|
||||
hardSigmoidId = addNodeToMatch("HardSigmoid", input);
|
||||
addNodeToMatch("Mul", input, hardSigmoidId);
|
||||
setFusedNode("HardSwish", input);
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
{
|
||||
Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[0]);
|
||||
Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[hardSigmoidId]);
|
||||
opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
|
||||
uint8_t matched = 0;
|
||||
@ -561,6 +524,9 @@ public:
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
int hardSigmoidId;
|
||||
};
|
||||
|
||||
class CeluSubgraph : public Subgraph
|
||||
@ -569,9 +535,9 @@ public:
|
||||
CeluSubgraph() : alpha(1.f)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int div = addNodeToMatch("Div", input, addNodeToMatch(""));
|
||||
int elu = addNodeToMatch("Elu", div);
|
||||
addNodeToMatch("Mul", addNodeToMatch(""), elu);
|
||||
div = addNodeToMatch("Div", input, addNodeToMatch(""));
|
||||
elu = addNodeToMatch("Elu", div);
|
||||
mul = addNodeToMatch("Mul", addNodeToMatch(""), elu);
|
||||
setFusedNode("Celu", input);
|
||||
}
|
||||
|
||||
@ -587,16 +553,15 @@ public:
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
{
|
||||
float alpha_div = extractAlpha(net, matchedNodesIds[0], 1);
|
||||
float alpha_mul = extractAlpha(net, matchedNodesIds[2], 0);
|
||||
float alpha_div = extractAlpha(net, matchedNodesIds[div], 1);
|
||||
float alpha_mul = extractAlpha(net, matchedNodesIds[mul], 0);
|
||||
float alpha_elu = 1.f;
|
||||
|
||||
Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[1]);
|
||||
Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[elu]);
|
||||
opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
|
||||
for (int i = 0; i < elu_node->attribute_size(); i++)
|
||||
@ -625,18 +590,18 @@ public:
|
||||
|
||||
protected:
|
||||
float alpha;
|
||||
int div, mul, elu;
|
||||
};
|
||||
|
||||
class NormalizeSubgraphBase : public Subgraph
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
|
||||
NormalizeSubgraphBase(int _normNodeOrder = 1) : axis(1), normNodeOrder(_normNodeOrder) {}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
{
|
||||
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
|
||||
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
@ -725,7 +690,7 @@ public:
|
||||
class NormalizeSubgraph3 : public NormalizeSubgraphBase
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraph3() : NormalizeSubgraphBase(1)
|
||||
NormalizeSubgraph3() : NormalizeSubgraphBase(3)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int power = addNodeToMatch("Constant");
|
||||
@ -743,7 +708,7 @@ public:
|
||||
class NormalizeSubgraph4 : public NormalizeSubgraphBase
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraph4() : NormalizeSubgraphBase(1)
|
||||
NormalizeSubgraph4() : NormalizeSubgraphBase(2)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int mul = addNodeToMatch("Mul", input, input);
|
||||
@ -760,7 +725,7 @@ public:
|
||||
class NormalizeSubgraph5 : public NormalizeSubgraphBase
|
||||
{
|
||||
public:
|
||||
NormalizeSubgraph5() : NormalizeSubgraphBase(1)
|
||||
NormalizeSubgraph5() : NormalizeSubgraphBase(2)
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int mul = addNodeToMatch("Mul", input, input);
|
||||
@ -781,25 +746,24 @@ public:
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int index = addNodeToMatch("Constant");
|
||||
int gather = addNodeToMatch("Gather", input, index);
|
||||
addNodeToMatch("Cast", gather);
|
||||
gather = addNodeToMatch("Gather", input, index);
|
||||
cast = addNodeToMatch("Cast", gather);
|
||||
setFusedNode("Gather", input, index);
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
|
||||
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds);
|
||||
size_t matchedNodesNum = matchedNodesIds.size();
|
||||
// Now we check if merging can be made for these Gather and Cast nodes
|
||||
if (!retVal || matchedNodesNum < 2)
|
||||
return retVal;
|
||||
else {
|
||||
int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
|
||||
int nodeToMatch = matchedNodesIds[cast];
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
||||
if (node->getType() == "Cast") {
|
||||
int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
|
||||
int inpNodeId = matchedNodesIds[gather];
|
||||
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
|
||||
if (inpNode->getType() == "Gather") {
|
||||
int numNodes = net->getNumNodes();
|
||||
@ -819,6 +783,9 @@ public:
|
||||
}
|
||||
return retVal;
|
||||
}
|
||||
|
||||
private:
|
||||
int cast, gather;
|
||||
};
|
||||
|
||||
/* Constant folding shape for Expand.
|
||||
@ -838,12 +805,12 @@ public:
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int values = addNodeToMatch("");
|
||||
int init = addNodeToMatch("ConstantOfShape", values);
|
||||
init = addNodeToMatch("ConstantOfShape", values);
|
||||
int coeff = addNodeToMatch("Constant");
|
||||
int mul = addNodeToMatch("Mul", init, coeff);
|
||||
mul = addNodeToMatch("Mul", init, coeff);
|
||||
int shape = addNodeToMatch("Constant");
|
||||
int condition = addNodeToMatch("Equal", shape, mul);
|
||||
int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
|
||||
condition = addNodeToMatch("Equal", shape, mul);
|
||||
where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
|
||||
addNodeToMatch("Expand", input, where);
|
||||
setFusedNode("Expand", input, shape);
|
||||
}
|
||||
@ -872,53 +839,28 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::vector<int64_t> extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
Mat mat_constant;
|
||||
if (initializer_id != -1) // initializer
|
||||
{
|
||||
mat_constant = onnx_net->getMatFromInitializer(initializer_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
||||
mat_constant = getMatFromTensor(constant_proto);
|
||||
}
|
||||
|
||||
std::vector<int64_t> retvals{mat_constant.begin<int>(), mat_constant.end<int>()};
|
||||
return retvals;
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE {
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) {
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
|
||||
int64_t value_ConstantOfShape;
|
||||
if (!extractValue(net, matchedNodesIds[0], value_ConstantOfShape)) {
|
||||
if (!extractValue(net, matchedNodesIds[init], value_ConstantOfShape)) {
|
||||
return false;
|
||||
}
|
||||
std::vector<int64_t> input_ConstantOfShape = extractConstant(net, matchedNodesIds[0], 0);
|
||||
std::vector<int> input_ConstantOfShape = extractConstant(net, matchedNodesIds[init], 0);
|
||||
if (input_ConstantOfShape.size() != static_cast<size_t>(1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto B_Mul = extractConstant(net, matchedNodesIds[1], 1);
|
||||
std::vector<int> B_Mul = extractConstant(net, matchedNodesIds[mul], 1);
|
||||
if (B_Mul.size() != static_cast<size_t>(1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto A_Equal = extractConstant(net, matchedNodesIds[2], 0);
|
||||
std::vector<int> A_Equal = extractConstant(net, matchedNodesIds[condition], 0);
|
||||
if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto Y_Where = extractConstant(net, matchedNodesIds[3], 2);
|
||||
std::vector<int> Y_Where = extractConstant(net, matchedNodesIds[where], 2);
|
||||
if (Y_Where.size() != A_Equal.size()) {
|
||||
return false;
|
||||
}
|
||||
@ -969,6 +911,9 @@ public:
|
||||
|
||||
protected:
|
||||
std::vector<int64_t> shape;
|
||||
|
||||
private:
|
||||
int init, mul, condition, where;
|
||||
};
|
||||
|
||||
class MishSubgraph : public Subgraph
|
||||
@ -979,7 +924,7 @@ public:
|
||||
int input = addNodeToMatch("");
|
||||
int softplus = addNodeToMatch("Softplus", input);
|
||||
int tanh = addNodeToMatch("Tanh", softplus);
|
||||
addNodeToMatch("Mul", input, tanh);
|
||||
addNodeToMatch("Mul", tanh, input);
|
||||
setFusedNode("Mish", input);
|
||||
}
|
||||
};
|
||||
@ -999,20 +944,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class SoftplusSubgraph2: public Subgraph
|
||||
{
|
||||
public:
|
||||
SoftplusSubgraph2()
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int exp = addNodeToMatch("Exp", input);
|
||||
int addVal = addNodeToMatch("");
|
||||
int add = addNodeToMatch("Add", exp, addVal);
|
||||
addNodeToMatch("Log", add);
|
||||
setFusedNode("Softplus", input);
|
||||
}
|
||||
};
|
||||
|
||||
class MulCastSubgraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
@ -1248,7 +1179,6 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
||||
subgraphs.push_back(makePtr<ExpandSubgraph>());
|
||||
subgraphs.push_back(makePtr<SoftplusSubgraph>());
|
||||
subgraphs.push_back(makePtr<SoftplusSubgraph2>());
|
||||
subgraphs.push_back(makePtr<MishSubgraph>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
|
||||
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
|
||||
|
@ -98,6 +98,14 @@ public:
|
||||
net.mutable_node()->DeleteSubrange(idx, 1);
|
||||
}
|
||||
|
||||
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
|
||||
{
|
||||
return type == "Add" || type == "Sum" ||
|
||||
type == "Mul" || type == "Prod" ||
|
||||
type == "Max" || type == "Maximum" || type == "Minimum" ||
|
||||
type == "Mean" || type == "SquaredDifference";
|
||||
}
|
||||
|
||||
tensorflow::GraphDef& net;
|
||||
};
|
||||
|
||||
@ -282,24 +290,26 @@ public:
|
||||
{
|
||||
int input = addNodeToMatch("");
|
||||
int relu = addNodeToMatch("Relu", input);
|
||||
int maxValue = addNodeToMatch("Const");
|
||||
maxValueId = addNodeToMatch("Const");
|
||||
int clipValue = addNodeToMatch("Const");
|
||||
int minimum = addNodeToMatch("Minimum", relu, maxValue);
|
||||
int minimum = addNodeToMatch("Minimum", relu, maxValueId);
|
||||
addNodeToMatch("Maximum", minimum, clipValue);
|
||||
|
||||
setFusedNode("Relu6", input);
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
if (!Subgraph::match(net, nodeId, matchedNodesIds))
|
||||
return false;
|
||||
tensorflow::NodeDef* node = net->getNode(matchedNodesIds.front() + 1).dynamicCast<TFNodeWrapper>()->node;
|
||||
tensorflow::NodeDef* node = net->getNode(matchedNodesIds[maxValueId]).dynamicCast<TFNodeWrapper>()->node;
|
||||
Mat maxValue = getTensorContent(node->attr().at("value").tensor());
|
||||
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
|
||||
}
|
||||
|
||||
private:
|
||||
int maxValueId;
|
||||
};
|
||||
|
||||
// Keras' reshape stores output shape in separate Const nodes by one value.
|
||||
@ -328,15 +338,14 @@ public:
|
||||
}
|
||||
|
||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||
std::vector<int>& matchedNodesIds,
|
||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
||||
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||
{
|
||||
Ptr<ImportNodeWrapper> node = net->getNode(nodeId);
|
||||
if (node->getNumInputs() == 0)
|
||||
return false;
|
||||
|
||||
inpName = node->getInputName(0);
|
||||
return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
|
||||
return Subgraph::match(net, nodeId, matchedNodesIds);
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user