mirror of
https://github.com/opencv/opencv.git
synced 2024-11-28 21:20:18 +08:00
Merge pull request #23296 from fengyuentau:fix_identifying_constant
Fix identifying initializers in ONNX graph simplification #23296 Fixes https://github.com/opencv/opencv/issues/23295 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
8336a96cb9
commit
3a83a35ab0
@ -82,7 +82,8 @@ public:
|
||||
for (int i = 0; i < numInitializers; ++i)
|
||||
if (net.initializer(i).name() == node_input_name)
|
||||
return i;
|
||||
CV_Error(Error::StsParseError, "Initializer with name " + node_input_name + " not found");
|
||||
// CV_Error(Error::StsParseError, "Initializer with name " + node_input_name + " not found");
|
||||
return -1;
|
||||
}
|
||||
|
||||
Mat getMatFromInitializer(int idx)
|
||||
@ -158,24 +159,17 @@ public:
|
||||
setFusedNode("Gelu", input);
|
||||
}
|
||||
|
||||
static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
// if node.getType() is Constant, Constant nodes are placed between other nodes
|
||||
if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
|
||||
return false;
|
||||
// if Initializer, there is no Constant node between other nodes
|
||||
return true;
|
||||
}
|
||||
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
|
||||
{
|
||||
if (withInitializer)
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
||||
return *const_mat.ptr<float>();
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
@ -192,21 +186,19 @@ public:
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
{
|
||||
bool withInitializer = isWithInitializer(matchedNodesIds);
|
||||
|
||||
// Check Div[B=sqrt(2)]
|
||||
float divisor = extractConstant(net, matchedNodesIds[0], 1, withInitializer);
|
||||
if (divisor - M_SQRT2 >= 1e-6)
|
||||
float divisor = extractConstant(net, matchedNodesIds[0], 1);
|
||||
if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon())
|
||||
return false;
|
||||
|
||||
// Check Add[B=1]
|
||||
float add_const = extractConstant(net, matchedNodesIds[2], 1, withInitializer);
|
||||
if (add_const - 1.f >= 1e-6)
|
||||
float add_const = extractConstant(net, matchedNodesIds[2], 1);
|
||||
if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon())
|
||||
return false;
|
||||
|
||||
// Check Mul[B=0.5]
|
||||
float mul_const = extractConstant(net, matchedNodesIds[4], 1, withInitializer);
|
||||
if (mul_const - 0.5f >= 1e-6)
|
||||
float mul_const = extractConstant(net, matchedNodesIds[4], 1);
|
||||
if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon())
|
||||
return false;
|
||||
|
||||
return true;
|
||||
@ -247,24 +239,17 @@ public:
|
||||
setFusedNode("GeluApproximation", input);
|
||||
}
|
||||
|
||||
static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
// if node.getType() is Constant, Constant nodes are placed between other nodes
|
||||
if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
|
||||
return false;
|
||||
// if Initializer, there is no Constant node between other nodes
|
||||
return true;
|
||||
}
|
||||
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
|
||||
{
|
||||
if (withInitializer)
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
||||
return *const_mat.ptr<float>();
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
@ -281,25 +266,23 @@ public:
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
{
|
||||
bool withInitializer = isWithInitializer(matchedNodesIds);
|
||||
|
||||
// Check Mul[A=0.044714998453855515]
|
||||
float coef = extractConstant(net, matchedNodesIds[2], 0, withInitializer);
|
||||
float coef = extractConstant(net, matchedNodesIds[2], 0);
|
||||
if (coef - 0.044714998453855515 >= 1e-6)
|
||||
return false;
|
||||
|
||||
// Check Mul[A=sqrt(2/pie)]
|
||||
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0, withInitializer);
|
||||
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0);
|
||||
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
|
||||
return false;
|
||||
|
||||
// Check Add[A=1]
|
||||
float add_const = extractConstant(net, matchedNodesIds[6], 0, withInitializer);
|
||||
float add_const = extractConstant(net, matchedNodesIds[6], 0);
|
||||
if (add_const - 1.f >= 1e-6)
|
||||
return false;
|
||||
|
||||
// Check Mul[A=0.5]
|
||||
float mul_const = extractConstant(net, matchedNodesIds[8], 0, withInitializer);
|
||||
float mul_const = extractConstant(net, matchedNodesIds[8], 0);
|
||||
if (mul_const - 0.5f >= 1e-6)
|
||||
return false;
|
||||
|
||||
@ -309,15 +292,25 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/* Fusion for LayerNormalization.
|
||||
|
||||
Graph before fusion
|
||||
+-> ReduceMean ->+
|
||||
| |
|
||||
[Input] -------> Sub -----------------------------------------------> Div -> Mul(B=weight) -> Add(B=bias) -> [Output]
|
||||
| |
|
||||
+-> Pow(Y=2) -> ReduceMean -> Add(B=epsilon) -> Sqrt ->+
|
||||
|
||||
Graph after fusion
|
||||
[Input] -> LayerNorm -> [Output]
|
||||
\
|
||||
[weight], [bias]
|
||||
*/
|
||||
class LayerNormSubGraph : public Subgraph
|
||||
{
|
||||
public:
|
||||
LayerNormSubGraph() : axis(-1), epsilon(1e-5)
|
||||
{
|
||||
// -> ReduceMean -> -> Pow(2) -> ReduceMean -> Add(epsilon) -> Sqrt ->
|
||||
// x Sub Div -> Mul(scale) -> Add(bias)
|
||||
// ---------------> ------------------------------------------------->
|
||||
// NOTE: Pow(2), Add(epsilon), Mul(scale), add(bias) can have constants as op_type Constant or Initializer
|
||||
int input = addNodeToMatch("");
|
||||
int mean = addNodeToMatch("ReduceMean", input);
|
||||
|
||||
@ -335,24 +328,17 @@ public:
|
||||
setFusedNode("LayerNormalization", input);
|
||||
}
|
||||
|
||||
static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
// if node.getType() is Constant, Constant nodes are placed between other nodes
|
||||
if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
|
||||
return false;
|
||||
// if Initializer, there is no nodes for constant between other nodes
|
||||
return true;
|
||||
}
|
||||
|
||||
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
|
||||
{
|
||||
if (withInitializer)
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1) // initializer
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
|
||||
return *const_mat.ptr<float>();
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
|
||||
int constant_id = getInputNodeId(net, node, input_id);
|
||||
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
|
||||
@ -378,14 +364,16 @@ public:
|
||||
return axis_;
|
||||
}
|
||||
|
||||
static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
|
||||
static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
|
||||
{
|
||||
if (withInitializer)
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
if (initializer_id != -1)
|
||||
{
|
||||
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
|
||||
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
|
||||
return onnx_net->getNameOfInitializer(initializer_id);
|
||||
} else {
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto node = net->getNode(node_id);
|
||||
return node->getInputName(input_id);
|
||||
}
|
||||
@ -397,9 +385,7 @@ public:
|
||||
{
|
||||
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
|
||||
{
|
||||
withInitializer = isWithInitializer(matchedNodesIds);
|
||||
|
||||
float pow_exp = extractConstant(net, matchedNodesIds[2], 1, withInitializer);
|
||||
float pow_exp = extractConstant(net, matchedNodesIds[2], 1);
|
||||
if (pow_exp - 2 > 1e-5) // not pow(2)
|
||||
return false;
|
||||
|
||||
@ -409,10 +395,10 @@ public:
|
||||
return false;
|
||||
axis = axis_mean1;
|
||||
|
||||
epsilon = extractConstant(net, matchedNodesIds[4], 1, withInitializer);
|
||||
epsilon = extractConstant(net, matchedNodesIds[4], 1);
|
||||
|
||||
weight_name = getInputName(net, matchedNodesIds[7], 1, withInitializer);
|
||||
bias_name = getInputName(net, matchedNodesIds[8], 1, withInitializer);
|
||||
weight_name = getInputName(net, matchedNodesIds[7], 1);
|
||||
bias_name = getInputName(net, matchedNodesIds[8], 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -440,7 +426,6 @@ public:
|
||||
protected:
|
||||
int axis;
|
||||
float epsilon;
|
||||
bool withInitializer;
|
||||
std::string weight_name;
|
||||
std::string bias_name;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user