Merge pull request #23296 from fengyuentau:fix_identifying_constant

Fix identifying initializers in ONNX graph simplification #23296

Fixes https://github.com/opencv/opencv/issues/23295

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Yuantao Feng 2023-04-06 15:35:31 +03:00 committed by GitHub
parent 8336a96cb9
commit 3a83a35ab0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -82,7 +82,8 @@ public:
for (int i = 0; i < numInitializers; ++i)
if (net.initializer(i).name() == node_input_name)
return i;
CV_Error(Error::StsParseError, "Initializer with name " + node_input_name + " not found");
// CV_Error(Error::StsParseError, "Initializer with name " + node_input_name + " not found");
return -1;
}
Mat getMatFromInitializer(int idx)
@ -158,24 +159,17 @@ public:
setFusedNode("Gelu", input);
}
static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
// if node.getType() is Constant, Constant nodes are placed between other nodes
if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
return false;
// if Initializer, there is no Constant node between other nodes
return true;
}
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
{
if (withInitializer)
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
} else {
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
@ -192,21 +186,19 @@ public:
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
{
bool withInitializer = isWithInitializer(matchedNodesIds);
// Check Div[B=sqrt(2)]
float divisor = extractConstant(net, matchedNodesIds[0], 1, withInitializer);
if (divisor - M_SQRT2 >= 1e-6)
float divisor = extractConstant(net, matchedNodesIds[0], 1);
if (std::fabs(divisor - M_SQRT2) >= std::numeric_limits<float>::epsilon())
return false;
// Check Add[B=1]
float add_const = extractConstant(net, matchedNodesIds[2], 1, withInitializer);
if (add_const - 1.f >= 1e-6)
float add_const = extractConstant(net, matchedNodesIds[2], 1);
if (std::fabs(add_const - 1.f) >= std::numeric_limits<float>::epsilon())
return false;
// Check Mul[B=0.5]
float mul_const = extractConstant(net, matchedNodesIds[4], 1, withInitializer);
if (mul_const - 0.5f >= 1e-6)
float mul_const = extractConstant(net, matchedNodesIds[4], 1);
if (std::fabs(mul_const - 0.5f) >= std::numeric_limits<float>::epsilon())
return false;
return true;
@ -247,24 +239,17 @@ public:
setFusedNode("GeluApproximation", input);
}
static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
// if node.getType() is Constant, Constant nodes are placed between other nodes
if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
return false;
// if Initializer, there is no Constant node between other nodes
return true;
}
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
{
if (withInitializer)
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
} else {
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
@ -281,25 +266,23 @@ public:
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
{
bool withInitializer = isWithInitializer(matchedNodesIds);
// Check Mul[A=0.044714998453855515]
float coef = extractConstant(net, matchedNodesIds[2], 0, withInitializer);
float coef = extractConstant(net, matchedNodesIds[2], 0);
if (coef - 0.044714998453855515 >= 1e-6)
return false;
// Check Mul[A=sqrt(2/pie)]
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0, withInitializer);
float sqrt_2_pie = extractConstant(net, matchedNodesIds[4], 0);
if (sqrt_2_pie - 0.7978845834732056 >= 1e-6)
return false;
// Check Add[A=1]
float add_const = extractConstant(net, matchedNodesIds[6], 0, withInitializer);
float add_const = extractConstant(net, matchedNodesIds[6], 0);
if (add_const - 1.f >= 1e-6)
return false;
// Check Mul[A=0.5]
float mul_const = extractConstant(net, matchedNodesIds[8], 0, withInitializer);
float mul_const = extractConstant(net, matchedNodesIds[8], 0);
if (mul_const - 0.5f >= 1e-6)
return false;
@ -309,15 +292,25 @@ public:
}
};
/* Fusion for LayerNormalization.
Graph before fusion
+-> ReduceMean ->+
| |
[Input] -------> Sub -----------------------------------------------> Div -> Mul(B=weight) -> Add(B=bias) -> [Output]
| |
+-> Pow(Y=2) -> ReduceMean -> Add(B=epsilon) -> Sqrt ->+
Graph after fusion
[Input] -> LayerNorm -> [Output]
\
[weight], [bias]
*/
class LayerNormSubGraph : public Subgraph
{
public:
LayerNormSubGraph() : axis(-1), epsilon(1e-5)
{
// -> ReduceMean -> -> Pow(2) -> ReduceMean -> Add(epsilon) -> Sqrt ->
// x Sub Div -> Mul(scale) -> Add(bias)
// ---------------> ------------------------------------------------->
// NOTE: Pow(2), Add(epsilon), Mul(scale), add(bias) can have constants as op_type Constant or Initializer
int input = addNodeToMatch("");
int mean = addNodeToMatch("ReduceMean", input);
@ -335,24 +328,17 @@ public:
setFusedNode("LayerNormalization", input);
}
static bool isWithInitializer(const std::vector<int>& matchedNodesIds)
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
// if node.getType() is Constant, Constant nodes are placed between other nodes
if (matchedNodesIds[2] - matchedNodesIds[1] != 1)
return false;
// if Initializer, there is no nodes for constant between other nodes
return true;
}
static float extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
{
if (withInitializer)
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1) // initializer
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
Mat const_mat = onnx_net->getMatFromInitializer(initializer_id);
return *const_mat.ptr<float>();
} else {
}
else
{
const Ptr<ImportNodeWrapper> node = net->getNode(node_id);
int constant_id = getInputNodeId(net, node, input_id);
Ptr<ImportNodeWrapper> constant_ptr = net->getNode(constant_id);
@ -378,14 +364,16 @@ public:
return axis_;
}
static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id, bool withInitializer)
static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
if (withInitializer)
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
return onnx_net->getNameOfInitializer(initializer_id);
} else {
}
else
{
const auto node = net->getNode(node_id);
return node->getInputName(input_id);
}
@ -397,9 +385,7 @@ public:
{
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
{
withInitializer = isWithInitializer(matchedNodesIds);
float pow_exp = extractConstant(net, matchedNodesIds[2], 1, withInitializer);
float pow_exp = extractConstant(net, matchedNodesIds[2], 1);
if (pow_exp - 2 > 1e-5) // not pow(2)
return false;
@ -409,10 +395,10 @@ public:
return false;
axis = axis_mean1;
epsilon = extractConstant(net, matchedNodesIds[4], 1, withInitializer);
epsilon = extractConstant(net, matchedNodesIds[4], 1);
weight_name = getInputName(net, matchedNodesIds[7], 1, withInitializer);
bias_name = getInputName(net, matchedNodesIds[8], 1, withInitializer);
weight_name = getInputName(net, matchedNodesIds[7], 1);
bias_name = getInputName(net, matchedNodesIds[8], 1);
return true;
}
@ -440,7 +426,6 @@ public:
protected:
int axis;
float epsilon;
bool withInitializer;
std::string weight_name;
std::string bias_name;
};