mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 14:36:36 +08:00
Merge pull request #24483 from dkurt:dnn_fusion_commutative_ops
Commutative rules for DNN subgraphs fusion #24483 ### Pull Request Readiness Checklist related: https://github.com/opencv/opencv/pull/24463#issuecomment-1783033931 See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
41c335e5a5
commit
b7ec2ebb55
@ -77,14 +77,14 @@ int Subgraph::getInputNodeId(const Ptr<ImportGraphWrapper>& net,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds)
|
||||||
std::vector<int>& targetNodesIds)
|
|
||||||
{
|
{
|
||||||
matchedNodesIds.clear();
|
matchedNodesIds.clear();
|
||||||
targetNodesIds.clear();
|
|
||||||
|
|
||||||
std::queue<int> nodesToMatch;
|
std::queue<int> nodesToMatch;
|
||||||
std::queue<int> targetNodes;
|
std::queue<int> targetNodes;
|
||||||
|
std::vector<std::pair<int, int> > matchings;
|
||||||
|
matchings.reserve(nodes.size());
|
||||||
nodesToMatch.push(nodeId);
|
nodesToMatch.push(nodeId);
|
||||||
targetNodes.push(nodes.size() - 1);
|
targetNodes.push(nodes.size() - 1);
|
||||||
while (!nodesToMatch.empty())
|
while (!nodesToMatch.empty())
|
||||||
@ -94,51 +94,63 @@ bool Subgraph::match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
|||||||
nodesToMatch.pop();
|
nodesToMatch.pop();
|
||||||
targetNodes.pop();
|
targetNodes.pop();
|
||||||
|
|
||||||
if (std::find(matchedNodesIds.begin(), matchedNodesIds.end(), nodeToMatch) !=
|
if (std::find_if(matchings.begin(), matchings.end(), [&](const std::pair<int, int>& match){ return match.first == targetNodeId; }) !=
|
||||||
matchedNodesIds.end())
|
matchings.end())
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
// Empty placeholder matches with any input type
|
||||||
|
if (nodes[targetNodeId].empty()) {
|
||||||
|
matchings.push_back({targetNodeId, nodeToMatch});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
||||||
if (node->getType() != nodes[targetNodeId])
|
if (node->getType() != nodes[targetNodeId])
|
||||||
return false;
|
continue;
|
||||||
|
|
||||||
std::vector<int>& inputNodes = inputs[targetNodeId];
|
std::vector<int>& inputNodes = inputs[targetNodeId];
|
||||||
if (inputNodes.size() != node->getNumInputs())
|
if (inputNodes.size() != node->getNumInputs())
|
||||||
return false;
|
continue;
|
||||||
|
|
||||||
|
bool isCommutative = net->isCommutativeOp(node->getType());
|
||||||
|
|
||||||
for (int j = 0; j < inputNodes.size(); ++j)
|
for (int j = 0; j < inputNodes.size(); ++j)
|
||||||
{
|
{
|
||||||
if (nodes[inputNodes[j]].empty() || node->getInputName(j).empty()) // Unknown input node type.
|
// Sometimes, ONNX may have input but it's empty (see Clip layer from reduceL2_subgraph2_2 testcase)
|
||||||
|
if (node->getInputName(j).empty())
|
||||||
continue;
|
continue;
|
||||||
nodeId = getInputNodeId(net, node, j);
|
nodeId = getInputNodeId(net, node, j);
|
||||||
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
|
const Ptr<ImportNodeWrapper> inpNode = net->getNode(nodeId);
|
||||||
if (inpNode->getType() != "Const" && inpNode->getType() != "Constant")
|
if (isCommutative)
|
||||||
|
{
|
||||||
|
for (int i = 0; i < inputNodes.size(); ++i)
|
||||||
|
{
|
||||||
|
nodesToMatch.push(nodeId);
|
||||||
|
targetNodes.push(inputNodes[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
{
|
{
|
||||||
nodesToMatch.push(nodeId);
|
nodesToMatch.push(nodeId);
|
||||||
targetNodes.push(inputNodes[j]);
|
targetNodes.push(inputNodes[j]);
|
||||||
}
|
}
|
||||||
else if (nodes[inputNodes[j]] != "Const" && nodes[inputNodes[j]] != "Constant")
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
matchedNodesIds.push_back(nodeToMatch);
|
matchings.push_back({targetNodeId, nodeToMatch});
|
||||||
targetNodesIds.push_back(targetNodeId);
|
|
||||||
}
|
}
|
||||||
|
if (matchings.size() != nodes.size())
|
||||||
|
return false;
|
||||||
|
|
||||||
const int n = matchedNodesIds.size();
|
// Sort matched by pattern nodes order.
|
||||||
std::vector<std::pair<int, int> > elements(n);
|
std::sort(matchings.begin(), matchings.end());
|
||||||
for (int i = 0; i < n; ++i)
|
matchedNodesIds.resize(matchings.size());
|
||||||
elements[i] = std::make_pair(matchedNodesIds[i], targetNodesIds[i]);
|
for (int i = 0; i < matchings.size(); ++i)
|
||||||
std::sort(elements.begin(), elements.end());
|
|
||||||
for (int i = 0; i < n; ++i)
|
|
||||||
{
|
{
|
||||||
matchedNodesIds[i] = elements[i].first;
|
matchedNodesIds[i] = matchings[i].second;
|
||||||
targetNodesIds[i] = elements[i].second;
|
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
|
void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds)
|
||||||
const std::vector<int>& targetNodesIds)
|
|
||||||
{
|
{
|
||||||
// Extract names of input nodes.
|
// Extract names of input nodes.
|
||||||
std::vector<std::string> inputsNames(fusedNodeInputs.size());
|
std::vector<std::string> inputsNames(fusedNodeInputs.size());
|
||||||
@ -149,9 +161,9 @@ void Subgraph::replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int
|
|||||||
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
|
for (int j = 0; j < matchedNodesIds.size() && inpName.empty(); ++j)
|
||||||
{
|
{
|
||||||
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]);
|
Ptr<ImportNodeWrapper> node = net->getNode(matchedNodesIds[j]);
|
||||||
std::vector<int>& inpIndices = inputs[targetNodesIds[j]];
|
std::vector<int>& inpIndices = inputs[j];
|
||||||
|
|
||||||
CV_Assert(node->getNumInputs() == inpIndices.size());
|
CV_Assert(inpIndices.empty() || node->getNumInputs() == inpIndices.size());
|
||||||
for (int k = 0; k < inpIndices.size(); ++k)
|
for (int k = 0; k < inpIndices.size(); ++k)
|
||||||
{
|
{
|
||||||
if (inpIndices[k] == fusedNodeInputs[i])
|
if (inpIndices[k] == fusedNodeInputs[i])
|
||||||
@ -187,15 +199,15 @@ void simplifySubgraphs(const Ptr<ImportGraphWrapper>& net,
|
|||||||
const std::vector<Ptr<Subgraph> >& patterns)
|
const std::vector<Ptr<Subgraph> >& patterns)
|
||||||
{
|
{
|
||||||
int numNodes = net->getNumNodes();
|
int numNodes = net->getNumNodes();
|
||||||
std::vector<int> matchedNodesIds, targetNodesIds;
|
std::vector<int> matchedNodesIds;
|
||||||
std::vector<int> nodesToRemove;
|
std::vector<int> nodesToRemove;
|
||||||
for (int j = 0; j < patterns.size(); ++j)
|
for (int j = 0; j < patterns.size(); ++j)
|
||||||
{
|
{
|
||||||
for (int i = 0; i < numNodes; ++i)
|
for (int i = 0; i < numNodes; ++i)
|
||||||
{
|
{
|
||||||
if (patterns[j]->match(net, i, matchedNodesIds, targetNodesIds))
|
if (patterns[j]->match(net, i, matchedNodesIds))
|
||||||
{
|
{
|
||||||
patterns[j]->replace(net, matchedNodesIds, targetNodesIds);
|
patterns[j]->replace(net, matchedNodesIds);
|
||||||
// Remove matched nodes except the last one.
|
// Remove matched nodes except the last one.
|
||||||
nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
|
nodesToRemove.insert(nodesToRemove.end(), matchedNodesIds.begin(), matchedNodesIds.end() - 1);
|
||||||
}
|
}
|
||||||
|
@ -44,6 +44,8 @@ public:
|
|||||||
virtual std::string getOutputName(int nodeId, int outId) const = 0;
|
virtual std::string getOutputName(int nodeId, int outId) const = 0;
|
||||||
|
|
||||||
virtual void removeNode(int idx) = 0;
|
virtual void removeNode(int idx) = 0;
|
||||||
|
|
||||||
|
virtual bool isCommutativeOp(const std::string& type) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Subgraph // Interface to match and replace subgraphs.
|
class Subgraph // Interface to match and replace subgraphs.
|
||||||
@ -75,12 +77,10 @@ public:
|
|||||||
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
|
// Match TensorFlow subgraph starting from <nodeId> with a set of nodes to be fused.
|
||||||
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
|
// Const nodes are skipped during matching. Returns true if nodes are matched and can be fused.
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds);
|
||||||
std::vector<int>& targetNodesIds);
|
|
||||||
|
|
||||||
// Fuse matched subgraph.
|
// Fuse matched subgraph.
|
||||||
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds,
|
void replace(const Ptr<ImportGraphWrapper>& net, const std::vector<int>& matchedNodesIds);
|
||||||
const std::vector<int>& targetNodesIds);
|
|
||||||
|
|
||||||
virtual void finalize(const Ptr<ImportGraphWrapper>& net,
|
virtual void finalize(const Ptr<ImportGraphWrapper>& net,
|
||||||
const Ptr<ImportNodeWrapper>& fusedNode,
|
const Ptr<ImportNodeWrapper>& fusedNode,
|
||||||
|
@ -125,8 +125,13 @@ public:
|
|||||||
|
|
||||||
virtual void removeNode(int idx) CV_OVERRIDE
|
virtual void removeNode(int idx) CV_OVERRIDE
|
||||||
{
|
{
|
||||||
CV_Assert(idx >= numInputs + numInitializers);
|
if (idx >= numInputs + numInitializers)
|
||||||
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
|
net.mutable_node()->DeleteSubrange(idx - numInputs - numInitializers, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
|
||||||
|
{
|
||||||
|
return type == "Add" || type == "Mul" || type == "Equal" || type == "Max";
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -134,6 +139,25 @@ private:
|
|||||||
opencv_onnx::GraphProto& net;
|
opencv_onnx::GraphProto& net;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||||
|
{
|
||||||
|
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||||
|
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||||
|
if (initializer_id != -1)
|
||||||
|
{
|
||||||
|
return onnx_net->getMatFromInitializer(initializer_id);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||||
|
int constant_id = Subgraph::getInputNodeId(net, node, input_id);
|
||||||
|
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||||
|
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||||
|
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
||||||
|
return getMatFromTensor(constant_proto);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* Fusion for Gelu.
|
/* Fusion for Gelu.
|
||||||
|
|
||||||
Graph before fusion:
|
Graph before fusion:
|
||||||
@ -151,54 +175,32 @@ public:
|
|||||||
GeluSubGraph()
|
GeluSubGraph()
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
|
div = addNodeToMatch("Div", input, addNodeToMatch("") /* B=sqrt(2) */ );
|
||||||
int erf = addNodeToMatch("Erf", div);
|
int erf = addNodeToMatch("Erf", div);
|
||||||
int add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
|
add = addNodeToMatch("Add", erf, addNodeToMatch("") /* B=1 */ );
|
||||||
int mul = addNodeToMatch("Mul", input, add);
|
int mul = addNodeToMatch("Mul", input, add);
|
||||||
addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
|
mul2 = addNodeToMatch("Mul", mul, addNodeToMatch("") /* B=0.5 */) ;
|
||||||
|
|
||||||
setFusedNode("Gelu", input);
|
setFusedNode("Gelu", input);
|
||||||
}
|
}
|
||||||
|
|
||||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
|
||||||
{
|
|
||||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
|
||||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
|
||||||
if (initializer_id != -1)
|
|
||||||
{
|
|
||||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
|
||||||
return *const_mat.ptr<float>();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
|
||||||
int constant_id = getInputNodeId(net, node, input_id);
|
|
||||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
|
||||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
|
||||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
|
||||||
Mat constant_mat = getMatFromTensor(constant_proto);
|
|
||||||
return *constant_mat.ptr<float>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
{
|
{
|
||||||
// Check Div[B=sqrt(2)]
|
// Check Div[B=sqrt(2)]
|
||||||
float divisor = extractConstant(net, matchedNodesIds[0], 1);
|
float divisor = extractConstant(net, matchedNodesIds[div], 1).at<float>(0);
|
||||||
if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon())
|
if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check Add[B=1]
|
// Check Add[B=1]
|
||||||
float add_const = extractConstant(net, matchedNodesIds[2], 1);
|
float add_const = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
|
||||||
if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon())
|
if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check Mul[B=0.5]
|
// Check Mul[B=0.5]
|
||||||
float mul_const = extractConstant(net, matchedNodesIds[4], 1);
|
float mul_const = extractConstant(net, matchedNodesIds[mul2], 1).at<float>(0);
|
||||||
if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon())
|
if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@ -206,6 +208,9 @@ public:
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int div, add, mul2;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Fusion for GeluApproximation.
|
/* Fusion for GeluApproximation.
|
||||||
@ -229,61 +234,39 @@ public:
|
|||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int mul0 = addNodeToMatch("Mul", input, input);
|
int mul0 = addNodeToMatch("Mul", input, input);
|
||||||
int mul1 = addNodeToMatch("Mul", input, mul0);
|
int mul1 = addNodeToMatch("Mul", input, mul0);
|
||||||
int mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
|
mul2 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.044714998453855515 */, mul1);
|
||||||
int add0 = addNodeToMatch("Add", input, mul2);
|
int add0 = addNodeToMatch("Add", input, mul2);
|
||||||
int mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
|
mul3 = addNodeToMatch("Mul", addNodeToMatch("") /* A=sqrt(2/pie) */, add0);
|
||||||
int tanh = addNodeToMatch("Tanh", mul3);
|
int tanh = addNodeToMatch("Tanh", mul3);
|
||||||
int add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
|
add1 = addNodeToMatch("Add", addNodeToMatch("") /* A=1 */, tanh);
|
||||||
int mul4 = addNodeToMatch("Mul", input, add1);
|
int mul4 = addNodeToMatch("Mul", input, add1);
|
||||||
addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
|
mul5 = addNodeToMatch("Mul", addNodeToMatch("") /* A=0.5 */, mul4);
|
||||||
|
|
||||||
setFusedNode("GeluApproximation", input);
|
setFusedNode("GeluApproximation", input);
|
||||||
}
|
}
|
||||||
|
|
||||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
|
||||||
{
|
|
||||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
|
||||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
|
||||||
if (initializer_id != -1)
|
|
||||||
{
|
|
||||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
|
||||||
return *const_mat.ptr<float>();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
|
||||||
int constant_id = getInputNodeId(net, node, input_id);
|
|
||||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
|
||||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
|
||||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
|
||||||
Mat constant_mat = getMatFromTensor(constant_proto);
|
|
||||||
return *constant_mat.ptr<float>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
{
|
{
|
||||||
// Check Mul[A=0.044714998453855515]
|
// Check Mul[A=0.044714998453855515]
|
||||||
float coef = extractConstant(net, matchedNodesIds[2], 0);
|
float coef = extractConstant(net, matchedNodesIds[mul2], 0).at<float>(0);
|
||||||
if (coef - 0.044714998453855515 >= 1e-6)
|
if (coef - 0.044714998453855515 >= 1e-6)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check Mul[A=sqrt(2/pie)]
|
// Check Mul[A=sqrt(2/pie)]
|
||||||
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0);
|
float sqrt_2_pie = extractConstant(net, matchedNodesIds[mul3], 0).at<float>(0);
|
||||||
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
|
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check Add[A=1]
|
// Check Add[A=1]
|
||||||
float add_const = extractConstant(net, matchedNodesIds[6], 0);
|
float add_const = extractConstant(net, matchedNodesIds[add1], 0).at<float>(0);
|
||||||
if (add_const - 1.f >= 1e-6)
|
if (add_const - 1.f >= 1e-6)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check Mul[A=0.5]
|
// Check Mul[A=0.5]
|
||||||
float mul_const = extractConstant(net, matchedNodesIds[8], 0);
|
float mul_const = extractConstant(net, matchedNodesIds[mul5], 0).at<float>(0);
|
||||||
if (mul_const - 0.5f >= 1e-6)
|
if (mul_const - 0.5f >= 1e-6)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
@ -291,6 +274,9 @@ public:
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int mul2, mul3, add1, mul5;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Fusion for LayerNormalization.
|
/* Fusion for LayerNormalization.
|
||||||
@ -313,43 +299,22 @@ public:
|
|||||||
LayerNormSubGraph() : axis(-1), epsilon(1e-5)
|
LayerNormSubGraph() : axis(-1), epsilon(1e-5)
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int mean = addNodeToMatch("ReduceMean", input);
|
mean = addNodeToMatch("ReduceMean", input);
|
||||||
|
|
||||||
int sub = addNodeToMatch("Sub", input, mean);
|
int sub = addNodeToMatch("Sub", input, mean);
|
||||||
|
|
||||||
int pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
|
pow = addNodeToMatch("Pow", sub, addNodeToMatch(""));
|
||||||
int mean1 = addNodeToMatch("ReduceMean", pow);
|
mean1 = addNodeToMatch("ReduceMean", pow);
|
||||||
int add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
|
add = addNodeToMatch("Add", mean1, addNodeToMatch(""));
|
||||||
int sqrt = addNodeToMatch("Sqrt", add);
|
int sqrt = addNodeToMatch("Sqrt", add);
|
||||||
|
|
||||||
int div = addNodeToMatch("Div", sub, sqrt);
|
int div = addNodeToMatch("Div", sub, sqrt);
|
||||||
int mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
|
mul = addNodeToMatch("Mul", div, addNodeToMatch(""));
|
||||||
addNodeToMatch("Add", mul, addNodeToMatch(""));
|
bias = addNodeToMatch("Add", mul, addNodeToMatch(""));
|
||||||
|
|
||||||
setFusedNode("LayerNormalization", input);
|
setFusedNode("LayerNormalization", input);
|
||||||
}
|
}
|
||||||
|
|
||||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
|
||||||
{
|
|
||||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
|
||||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
|
||||||
if (initializer_id != -1) // initializer
|
|
||||||
{
|
|
||||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
|
||||||
return *const_mat.ptr<float>();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
|
||||||
int constant_id = getInputNodeId(net, node, input_id);
|
|
||||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
|
||||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
|
||||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
|
||||||
Mat constant_mat = getMatFromTensor(constant_proto);
|
|
||||||
return *constant_mat.ptr<float>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
|
static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
|
||||||
{
|
{
|
||||||
Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
|
Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
|
||||||
@ -381,25 +346,24 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
{
|
{
|
||||||
float pow_exp = extractConstant(net, matchedNodesIds[2], 1);
|
float pow_exp = extractConstant(net, matchedNodesIds[pow], 1).at<float>(0);
|
||||||
if (pow_exp - 2 > 1e-5) // not pow(2)
|
if (pow_exp - 2 > 1e-5) // not pow(2)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
int axis_mean1 = extractAxis(net, matchedNodesIds[0]);
|
int axis_mean1 = extractAxis(net, matchedNodesIds[mean]);
|
||||||
int axis_mean2 = extractAxis(net, matchedNodesIds[3]);
|
int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]);
|
||||||
if (axis_mean1 != axis_mean2)
|
if (axis_mean1 != axis_mean2)
|
||||||
return false;
|
return false;
|
||||||
axis = axis_mean1;
|
axis = axis_mean1;
|
||||||
|
|
||||||
epsilon = extractConstant(net, matchedNodesIds[4], 1);
|
epsilon = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
|
||||||
|
|
||||||
weight_name = getInputName(net, matchedNodesIds[7], 1);
|
weight_name = getInputName(net, matchedNodesIds[mul], 1);
|
||||||
bias_name = getInputName(net, matchedNodesIds[8], 1);
|
bias_name = getInputName(net, matchedNodesIds[bias], 1);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -429,6 +393,7 @@ protected:
|
|||||||
float epsilon;
|
float epsilon;
|
||||||
std::string weight_name;
|
std::string weight_name;
|
||||||
std::string bias_name;
|
std::string bias_name;
|
||||||
|
int pow, mean, mean1, add, mul, bias;
|
||||||
};
|
};
|
||||||
|
|
||||||
class SoftMaxSubgraphBase : public Subgraph
|
class SoftMaxSubgraphBase : public Subgraph
|
||||||
@ -437,10 +402,9 @@ public:
|
|||||||
SoftMaxSubgraphBase() : axis(1), id(-1) {}
|
SoftMaxSubgraphBase() : axis(1), id(-1) {}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
{
|
{
|
||||||
CV_Assert(id >= 0 && id < matchedNodesIds.size());
|
CV_Assert(id >= 0 && id < matchedNodesIds.size());
|
||||||
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
|
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
|
||||||
@ -485,7 +449,7 @@ public:
|
|||||||
int inpExp = addNodeToMatch("Exp", input);
|
int inpExp = addNodeToMatch("Exp", input);
|
||||||
|
|
||||||
int sum = addNodeToMatch("ReduceSum", inpExp);
|
int sum = addNodeToMatch("ReduceSum", inpExp);
|
||||||
id = 1;
|
id = sum;
|
||||||
|
|
||||||
addNodeToMatch("Div", inpExp, sum);
|
addNodeToMatch("Div", inpExp, sum);
|
||||||
setFusedNode("Softmax", input);
|
setFusedNode("Softmax", input);
|
||||||
@ -498,7 +462,7 @@ public:
|
|||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
|
|
||||||
int reducemax = addNodeToMatch("ReduceMax", input);
|
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||||
id = 0;
|
id = reducemax;
|
||||||
|
|
||||||
int sub = addNodeToMatch("Sub", input, reducemax);
|
int sub = addNodeToMatch("Sub", input, reducemax);
|
||||||
int exp = addNodeToMatch("Exp", sub);
|
int exp = addNodeToMatch("Exp", sub);
|
||||||
@ -516,7 +480,7 @@ public:
|
|||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
|
|
||||||
int reducemax = addNodeToMatch("ReduceMax", input);
|
int reducemax = addNodeToMatch("ReduceMax", input);
|
||||||
id = 0;
|
id = reducemax;
|
||||||
|
|
||||||
int sub_1 = addNodeToMatch("Sub", input, reducemax);
|
int sub_1 = addNodeToMatch("Sub", input, reducemax);
|
||||||
int exp = addNodeToMatch("Exp", sub_1);
|
int exp = addNodeToMatch("Exp", sub_1);
|
||||||
@ -533,18 +497,17 @@ public:
|
|||||||
HardSwishSubgraph()
|
HardSwishSubgraph()
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int hardSigmoid = addNodeToMatch("HardSigmoid", input);
|
hardSigmoidId = addNodeToMatch("HardSigmoid", input);
|
||||||
addNodeToMatch("Mul", input, hardSigmoid);
|
addNodeToMatch("Mul", input, hardSigmoidId);
|
||||||
setFusedNode("HardSwish", input);
|
setFusedNode("HardSwish", input);
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
{
|
{
|
||||||
Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[0]);
|
Ptr<ImportNodeWrapper> hardSigmoid = net->getNode(matchedNodesIds[hardSigmoidId]);
|
||||||
opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast<ONNXNodeWrapper>()->node;
|
opencv_onnx::NodeProto* node = hardSigmoid.dynamicCast<ONNXNodeWrapper>()->node;
|
||||||
|
|
||||||
uint8_t matched = 0;
|
uint8_t matched = 0;
|
||||||
@ -561,6 +524,9 @@ public:
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int hardSigmoidId;
|
||||||
};
|
};
|
||||||
|
|
||||||
class CeluSubgraph : public Subgraph
|
class CeluSubgraph : public Subgraph
|
||||||
@ -569,9 +535,9 @@ public:
|
|||||||
CeluSubgraph() : alpha(1.f)
|
CeluSubgraph() : alpha(1.f)
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int div = addNodeToMatch("Div", input, addNodeToMatch(""));
|
div = addNodeToMatch("Div", input, addNodeToMatch(""));
|
||||||
int elu = addNodeToMatch("Elu", div);
|
elu = addNodeToMatch("Elu", div);
|
||||||
addNodeToMatch("Mul", addNodeToMatch(""), elu);
|
mul = addNodeToMatch("Mul", addNodeToMatch(""), elu);
|
||||||
setFusedNode("Celu", input);
|
setFusedNode("Celu", input);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -587,16 +553,15 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
{
|
{
|
||||||
float alpha_div = extractAlpha(net, matchedNodesIds[0], 1);
|
float alpha_div = extractAlpha(net, matchedNodesIds[div], 1);
|
||||||
float alpha_mul = extractAlpha(net, matchedNodesIds[2], 0);
|
float alpha_mul = extractAlpha(net, matchedNodesIds[mul], 0);
|
||||||
float alpha_elu = 1.f;
|
float alpha_elu = 1.f;
|
||||||
|
|
||||||
Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[1]);
|
Ptr<ImportNodeWrapper> elu_ptr = net->getNode(matchedNodesIds[elu]);
|
||||||
opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
opencv_onnx::NodeProto* elu_node = elu_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
||||||
|
|
||||||
for (int i = 0; i < elu_node->attribute_size(); i++)
|
for (int i = 0; i < elu_node->attribute_size(); i++)
|
||||||
@ -625,18 +590,18 @@ public:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
float alpha;
|
float alpha;
|
||||||
|
int div, mul, elu;
|
||||||
};
|
};
|
||||||
|
|
||||||
class NormalizeSubgraphBase : public Subgraph
|
class NormalizeSubgraphBase : public Subgraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
NormalizeSubgraphBase(int _normNodeOrder = 0) : axis(1), normNodeOrder(_normNodeOrder) {}
|
NormalizeSubgraphBase(int _normNodeOrder = 1) : axis(1), normNodeOrder(_normNodeOrder) {}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
{
|
{
|
||||||
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
|
Ptr<ImportNodeWrapper> norm = net->getNode(matchedNodesIds[normNodeOrder]);
|
||||||
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
|
opencv_onnx::NodeProto* node = norm.dynamicCast<ONNXNodeWrapper>()->node;
|
||||||
@ -725,7 +690,7 @@ public:
|
|||||||
class NormalizeSubgraph3 : public NormalizeSubgraphBase
|
class NormalizeSubgraph3 : public NormalizeSubgraphBase
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
NormalizeSubgraph3() : NormalizeSubgraphBase(1)
|
NormalizeSubgraph3() : NormalizeSubgraphBase(3)
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int power = addNodeToMatch("Constant");
|
int power = addNodeToMatch("Constant");
|
||||||
@ -743,7 +708,7 @@ public:
|
|||||||
class NormalizeSubgraph4 : public NormalizeSubgraphBase
|
class NormalizeSubgraph4 : public NormalizeSubgraphBase
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
NormalizeSubgraph4() : NormalizeSubgraphBase(1)
|
NormalizeSubgraph4() : NormalizeSubgraphBase(2)
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int mul = addNodeToMatch("Mul", input, input);
|
int mul = addNodeToMatch("Mul", input, input);
|
||||||
@ -760,7 +725,7 @@ public:
|
|||||||
class NormalizeSubgraph5 : public NormalizeSubgraphBase
|
class NormalizeSubgraph5 : public NormalizeSubgraphBase
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
NormalizeSubgraph5() : NormalizeSubgraphBase(1)
|
NormalizeSubgraph5() : NormalizeSubgraphBase(2)
|
||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int mul = addNodeToMatch("Mul", input, input);
|
int mul = addNodeToMatch("Mul", input, input);
|
||||||
@ -781,25 +746,24 @@ public:
|
|||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int index = addNodeToMatch("Constant");
|
int index = addNodeToMatch("Constant");
|
||||||
int gather = addNodeToMatch("Gather", input, index);
|
gather = addNodeToMatch("Gather", input, index);
|
||||||
addNodeToMatch("Cast", gather);
|
cast = addNodeToMatch("Cast", gather);
|
||||||
setFusedNode("Gather", input, index);
|
setFusedNode("Gather", input, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
|
bool retVal = Subgraph::match(net, nodeId, matchedNodesIds);
|
||||||
size_t matchedNodesNum = matchedNodesIds.size();
|
size_t matchedNodesNum = matchedNodesIds.size();
|
||||||
// Now we check if merging can be made for these Gather and Cast nodes
|
// Now we check if merging can be made for these Gather and Cast nodes
|
||||||
if (!retVal || matchedNodesNum < 2)
|
if (!retVal || matchedNodesNum < 2)
|
||||||
return retVal;
|
return retVal;
|
||||||
else {
|
else {
|
||||||
int nodeToMatch = matchedNodesIds[matchedNodesNum - 1];
|
int nodeToMatch = matchedNodesIds[cast];
|
||||||
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
const Ptr<ImportNodeWrapper> node = net->getNode(nodeToMatch);
|
||||||
if (node->getType() == "Cast") {
|
if (node->getType() == "Cast") {
|
||||||
int inpNodeId = matchedNodesIds[matchedNodesNum - 2];
|
int inpNodeId = matchedNodesIds[gather];
|
||||||
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
|
const Ptr<ImportNodeWrapper> inpNode = net->getNode(inpNodeId);
|
||||||
if (inpNode->getType() == "Gather") {
|
if (inpNode->getType() == "Gather") {
|
||||||
int numNodes = net->getNumNodes();
|
int numNodes = net->getNumNodes();
|
||||||
@ -819,6 +783,9 @@ public:
|
|||||||
}
|
}
|
||||||
return retVal;
|
return retVal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int cast, gather;
|
||||||
};
|
};
|
||||||
|
|
||||||
/* Constant folding shape for Expand.
|
/* Constant folding shape for Expand.
|
||||||
@ -838,12 +805,12 @@ public:
|
|||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int values = addNodeToMatch("");
|
int values = addNodeToMatch("");
|
||||||
int init = addNodeToMatch("ConstantOfShape", values);
|
init = addNodeToMatch("ConstantOfShape", values);
|
||||||
int coeff = addNodeToMatch("Constant");
|
int coeff = addNodeToMatch("Constant");
|
||||||
int mul = addNodeToMatch("Mul", init, coeff);
|
mul = addNodeToMatch("Mul", init, coeff);
|
||||||
int shape = addNodeToMatch("Constant");
|
int shape = addNodeToMatch("Constant");
|
||||||
int condition = addNodeToMatch("Equal", shape, mul);
|
condition = addNodeToMatch("Equal", shape, mul);
|
||||||
int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
|
where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
|
||||||
addNodeToMatch("Expand", input, where);
|
addNodeToMatch("Expand", input, where);
|
||||||
setFusedNode("Expand", input, shape);
|
setFusedNode("Expand", input, shape);
|
||||||
}
|
}
|
||||||
@ -872,53 +839,28 @@ public:
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<int64_t> extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
|
||||||
{
|
|
||||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
|
||||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
|
||||||
Mat mat_constant;
|
|
||||||
if (initializer_id != -1) // initializer
|
|
||||||
{
|
|
||||||
mat_constant = onnx_net->getMatFromInitializer(initializer_id);
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
|
||||||
int constant_id = getInputNodeId(net, node, input_id);
|
|
||||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
|
||||||
opencv_onnx::NodeProto* constant_node = constant_ptr.dynamicCast<ONNXNodeWrapper>()->node;
|
|
||||||
opencv_onnx::TensorProto constant_proto = constant_node->attribute(0).t();
|
|
||||||
mat_constant = getMatFromTensor(constant_proto);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<int64_t> retvals{mat_constant.begin<int>(), mat_constant.end<int>()};
|
|
||||||
return retvals;
|
|
||||||
}
|
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE {
|
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
|
||||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) {
|
|
||||||
int64_t value_ConstantOfShape;
|
int64_t value_ConstantOfShape;
|
||||||
if (!extractValue(net, matchedNodesIds[0], value_ConstantOfShape)) {
|
if (!extractValue(net, matchedNodesIds[init], value_ConstantOfShape)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
std::vector<int64_t> input_ConstantOfShape = extractConstant(net, matchedNodesIds[0], 0);
|
std::vector<int> input_ConstantOfShape = extractConstant(net, matchedNodesIds[init], 0);
|
||||||
if (input_ConstantOfShape.size() != static_cast<size_t>(1)) {
|
if (input_ConstantOfShape.size() != static_cast<size_t>(1)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
std::vector<int> B_Mul = extractConstant(net, matchedNodesIds[mul], 1);
|
||||||
auto B_Mul = extractConstant(net, matchedNodesIds[1], 1);
|
|
||||||
if (B_Mul.size() != static_cast<size_t>(1)) {
|
if (B_Mul.size() != static_cast<size_t>(1)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto A_Equal = extractConstant(net, matchedNodesIds[2], 0);
|
std::vector<int> A_Equal = extractConstant(net, matchedNodesIds[condition], 0);
|
||||||
if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) {
|
if (A_Equal.size() != static_cast<size_t>(input_ConstantOfShape[0])) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto Y_Where = extractConstant(net, matchedNodesIds[3], 2);
|
std::vector<int> Y_Where = extractConstant(net, matchedNodesIds[where], 2);
|
||||||
if (Y_Where.size() != A_Equal.size()) {
|
if (Y_Where.size() != A_Equal.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -969,6 +911,9 @@ public:
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
std::vector<int64_t> shape;
|
std::vector<int64_t> shape;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int init, mul, condition, where;
|
||||||
};
|
};
|
||||||
|
|
||||||
class MishSubgraph : public Subgraph
|
class MishSubgraph : public Subgraph
|
||||||
@ -979,7 +924,7 @@ public:
|
|||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int softplus = addNodeToMatch("Softplus", input);
|
int softplus = addNodeToMatch("Softplus", input);
|
||||||
int tanh = addNodeToMatch("Tanh", softplus);
|
int tanh = addNodeToMatch("Tanh", softplus);
|
||||||
addNodeToMatch("Mul", input, tanh);
|
addNodeToMatch("Mul", tanh, input);
|
||||||
setFusedNode("Mish", input);
|
setFusedNode("Mish", input);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -999,20 +944,6 @@ public:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class SoftplusSubgraph2: public Subgraph
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
SoftplusSubgraph2()
|
|
||||||
{
|
|
||||||
int input = addNodeToMatch("");
|
|
||||||
int exp = addNodeToMatch("Exp", input);
|
|
||||||
int addVal = addNodeToMatch("");
|
|
||||||
int add = addNodeToMatch("Add", exp, addVal);
|
|
||||||
addNodeToMatch("Log", add);
|
|
||||||
setFusedNode("Softplus", input);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class MulCastSubgraph : public Subgraph
|
class MulCastSubgraph : public Subgraph
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -1248,7 +1179,6 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
|
|||||||
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
|
||||||
subgraphs.push_back(makePtr<ExpandSubgraph>());
|
subgraphs.push_back(makePtr<ExpandSubgraph>());
|
||||||
subgraphs.push_back(makePtr<SoftplusSubgraph>());
|
subgraphs.push_back(makePtr<SoftplusSubgraph>());
|
||||||
subgraphs.push_back(makePtr<SoftplusSubgraph2>());
|
|
||||||
subgraphs.push_back(makePtr<MishSubgraph>());
|
subgraphs.push_back(makePtr<MishSubgraph>());
|
||||||
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
|
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
|
||||||
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
|
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
|
||||||
|
@ -98,6 +98,14 @@ public:
|
|||||||
net.mutable_node()->DeleteSubrange(idx, 1);
|
net.mutable_node()->DeleteSubrange(idx, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
virtual inline bool isCommutativeOp(const std::string& type) const CV_OVERRIDE
|
||||||
|
{
|
||||||
|
return type == "Add" || type == "Sum" ||
|
||||||
|
type == "Mul" || type == "Prod" ||
|
||||||
|
type == "Max" || type == "Maximum" || type == "Minimum" ||
|
||||||
|
type == "Mean" || type == "SquaredDifference";
|
||||||
|
}
|
||||||
|
|
||||||
tensorflow::GraphDef& net;
|
tensorflow::GraphDef& net;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -282,24 +290,26 @@ public:
|
|||||||
{
|
{
|
||||||
int input = addNodeToMatch("");
|
int input = addNodeToMatch("");
|
||||||
int relu = addNodeToMatch("Relu", input);
|
int relu = addNodeToMatch("Relu", input);
|
||||||
int maxValue = addNodeToMatch("Const");
|
maxValueId = addNodeToMatch("Const");
|
||||||
int clipValue = addNodeToMatch("Const");
|
int clipValue = addNodeToMatch("Const");
|
||||||
int minimum = addNodeToMatch("Minimum", relu, maxValue);
|
int minimum = addNodeToMatch("Minimum", relu, maxValueId);
|
||||||
addNodeToMatch("Maximum", minimum, clipValue);
|
addNodeToMatch("Maximum", minimum, clipValue);
|
||||||
|
|
||||||
setFusedNode("Relu6", input);
|
setFusedNode("Relu6", input);
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
if (!Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
if (!Subgraph::match(net, nodeId, matchedNodesIds))
|
||||||
return false;
|
return false;
|
||||||
tensorflow::NodeDef* node = net->getNode(matchedNodesIds.front() + 1).dynamicCast<TFNodeWrapper>()->node;
|
tensorflow::NodeDef* node = net->getNode(matchedNodesIds[maxValueId]).dynamicCast<TFNodeWrapper>()->node;
|
||||||
Mat maxValue = getTensorContent(node->attr().at("value").tensor());
|
Mat maxValue = getTensorContent(node->attr().at("value").tensor());
|
||||||
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
|
return maxValue.type() == CV_32FC1 && maxValue.total() == 1 && maxValue.at<float>(0) == 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
int maxValueId;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Keras' reshape stores output shape in separate Const nodes by one value.
|
// Keras' reshape stores output shape in separate Const nodes by one value.
|
||||||
@ -328,15 +338,14 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
|
||||||
std::vector<int>& matchedNodesIds,
|
std::vector<int>& matchedNodesIds) CV_OVERRIDE
|
||||||
std::vector<int>& targetNodesIds) CV_OVERRIDE
|
|
||||||
{
|
{
|
||||||
Ptr<ImportNodeWrapper> node = net->getNode(nodeId);
|
Ptr<ImportNodeWrapper> node = net->getNode(nodeId);
|
||||||
if (node->getNumInputs() == 0)
|
if (node->getNumInputs() == 0)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
inpName = node->getInputName(0);
|
inpName = node->getInputName(0);
|
||||||
return Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds);
|
return Subgraph::match(net, nodeId, matchedNodesIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user