Mirror of https://github.com/opencv/opencv.git (synced 2025-06-08 01:53:19 +08:00)
Merge pull request #24808 from fengyuentau:fix_layernorm
dnn: no layer norm fusion if axes.back() is not the axis of the last dimension #24808

Merge with https://github.com/opencv/opencv_extra/pull/1137.
Resolves https://github.com/opencv/opencv/issues/24797.

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under the Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under the GPL or another license that is incompatible with OpenCV.
- [x] The PR is proposed to the proper branch.
- [x] There is a reference to the original bug report and related work.
- [x] There are accuracy tests, performance tests, and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name.
- [ ] The feature is well documented and sample code can be built with the project CMake.
parent 75dc334d39
commit 7fb336322d
@@ -86,6 +86,7 @@ public:
     int getTensorShapeSize(int node_id, int node_input_id) {
         const auto node = getNode(node_id);
         const auto &input_name = node->getInputName(node_input_id);
+        // try to get from value_info
         for (int i = 0; i < net.value_info_size(); i++) {
             const auto value_info = net.value_info(i);
             if (value_info.name() == input_name) {
@@ -97,6 +98,18 @@ public:
                 }
             }
         }
+        // try to get from input
+        for (int i = 0; i < net.input_size(); i++) {
+            const auto input = net.input(i);
+            if (input.name() == input_name) {
+                if (input.has_type() && input.type().has_tensor_type() &&
+                    input.type().tensor_type().has_shape()) {
+                    return input.type().tensor_type().shape().dim_size();
+                } else {
+                    return -1;
+                }
+            }
+        }
         return -1;
     }
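In the lookup above, `value_info` only contains entries that ONNX shape inference has populated, so the new fallback to the declared graph inputs covers tensors whose rank is stated only at the model boundary. Below is a minimal self-contained sketch of that two-stage rank lookup; the types `TensorInfo` and `Graph` are invented stand-ins for the ONNX protobuf classes, not OpenCV or ONNX API.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-ins for the ONNX protobuf types (ValueInfoProto etc.);
// these names are invented for illustration only.
struct TensorInfo {
    std::string name;
    int ndims; // tensor rank; -1 when unrecorded
};

struct Graph {
    std::vector<TensorInfo> value_info; // filled by shape inference, may be incomplete
    std::vector<TensorInfo> inputs;     // declared graph inputs
};

// Mirrors the patched getTensorShapeSize: prefer value_info, fall back
// to the graph inputs, and report -1 when the rank is unknown.
int tensorRank(const Graph& g, const std::string& name) {
    for (const auto& vi : g.value_info)
        if (vi.name == name)
            return vi.ndims;
    for (const auto& in : g.inputs)
        if (in.name == name)
            return in.ndims;
    return -1; // callers treat "unknown" as "do not fuse"
}

int main() {
    Graph g;
    g.inputs.push_back({"input", 4}); // e.g. an NCHW tensor
    assert(tensorRank(g, "input") == 4);
    assert(tensorRank(g, "missing") == -1);
    return 0;
}
```

Returning -1 rather than throwing lets the fusion matcher treat an unknown rank as a soft failure and simply skip the rewrite.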
@@ -660,6 +673,10 @@ private:
     [Input] -> LayerNorm -> [Output]
                    \
                [weight], [bias]
+
+    Note: axes of ReduceMean must be:
+        - last element is the axis of last dimension (-1 or (input_ndims - 1))
+        - a list of adjacent axes, e.g. [1, 2, 3, ..., input_ndims - 1]
 */
class LayerNormSubGraph : public Subgraph
{
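These two constraints are exactly what makes the subgraph expressible as a single LayerNormalization, whose lone `axis` attribute marks where the normalized trailing block of dimensions begins. A standalone sketch of the precondition (not the OpenCV implementation; the helper name is invented):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// True when `axes` is a contiguous ascending run that ends at the last
// dimension -- the only reduction pattern LayerNormalization can
// represent with a single `axis` attribute.
bool axesAllowFusion(const std::vector<int64_t>& axes, int64_t input_ndims) {
    if (axes.empty() || input_ndims < 1)
        return false;
    // Last element must denote the last dimension: -1 or input_ndims - 1.
    if (axes.back() != -1 && axes.back() != input_ndims - 1)
        return false;
    // Elements must be adjacent, e.g. [1, 2, 3] or [-3, -2, -1].
    for (size_t i = 0; i + 1 < axes.size(); i++)
        if (axes[i + 1] - axes[i] != 1)
            return false;
    return true;
}

int main() {
    assert(axesAllowFusion({1, 2, 3}, 4));    // trailing block of a 4-D input
    assert(axesAllowFusion({-3, -2, -1}, 4)); // same block, negative spelling
    assert(!axesAllowFusion({1}, 3));         // not fusable: axis 1 is not the
                                              // last dimension of a 3-D input
    return 0;
}
```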
@@ -683,19 +700,22 @@ public:
         setFusedNode("LayerNormalization", input);
     }

-    static float extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
+    static std::vector<int64_t> extractAxis(const Ptr<ImportGraphWrapper>& net, int node_id)
     {
+        // TODO: consider ReduceMean-18 which has axes as one of the inputs instead of attributes
         Ptr<ImportNodeWrapper> mean_ptr = net->getNode(node_id);
         opencv_onnx::NodeProto* mean_node = mean_ptr.dynamicCast<ONNXNodeWrapper>()->node;
-        int axis_ = -1;
+        std::vector<int64_t> axes;
         for (int i = 0; i < mean_node->attribute_size(); i++)
         {
             opencv_onnx::AttributeProto attr = mean_node->attribute(i);
             if (attr.name() != "axes")
                 continue;
-            axis_ = static_cast<int>(attr.ints(0));
+            for (int j = 0; j < attr.ints_size(); j++) {
+                axes.push_back(attr.ints(j));
+            }
         }
-        return axis_;
+        return axes;
     }

     virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
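The TODO matters because ONNX moved `axes` from an attribute to an optional second input in ReduceMean-18, so this extractor returns an empty list for such models and the fusion is skipped. A rough sketch of what an opset-aware extractor could look like, with `Node` as an invented stand-in for the protobuf node plus its resolved constant inputs:

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Invented stand-in for an ONNX node: integer-list attributes by name,
// plus any constant integer-list inputs already resolved to values.
struct Node {
    std::map<std::string, std::vector<int64_t>> int_list_attrs;
    std::vector<std::vector<int64_t>> int_list_inputs;
};

// Until opset 18 ReduceMean carried `axes` as an attribute; from opset 18
// they arrive as an optional second input. An opset-aware extractor would
// branch on that, roughly like this:
std::vector<int64_t> extractAxes(const Node& n, int opset) {
    if (opset < 18) {
        auto it = n.int_list_attrs.find("axes");
        return it != n.int_list_attrs.end() ? it->second : std::vector<int64_t>{};
    }
    if (n.int_list_inputs.size() > 1)
        return n.int_list_inputs[1]; // input 0 is data, input 1 is axes
    return {}; // axes omitted: by default the reduction covers all dimensions
}

int main() {
    Node n13, n18;
    n13.int_list_attrs["axes"] = {-2, -1}; // pre-18 style
    n18.int_list_inputs = {{}, {-2, -1}};  // opset-18 style
    return extractAxes(n13, 13) == extractAxes(n18, 18) ? 0 : 1;
}
```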
@@ -707,11 +727,31 @@ public:
         if (pow_exp - 2 > 1e-5) // not pow(2)
             return false;

-        int axis_mean1 = extractAxis(net, matchedNodesIds[mean]);
-        int axis_mean2 = extractAxis(net, matchedNodesIds[mean1]);
-        if (axis_mean1 != axis_mean2)
+        std::vector<int64_t> axes = extractAxis(net, matchedNodesIds[mean]);
+        // check whether it is -1 or last_axis or [axis, ..., last_axis]
+        int64_t input_ndims = static_cast<int64_t>(net.dynamicCast<ONNXGraphWrapper>()->getTensorShapeSize(matchedNodesIds[mean], 0));
+        if (input_ndims == -1) {
+            return false; // input shape unknown
+        }
+        // assume that axes are sorted in ascending order, e.g. [0, 1, 2, 3] or [-3, -2, -1]
+        if (axes.back() != -1 && axes.back() != (input_ndims - 1)) {
             return false;
-        axis = axis_mean1;
+        }
+        for (size_t i = 0; i < axes.size() - 1; i++) {
+            if (axes[i] - axes[i + 1] != -1) {
+                return false;
+            }
+        }
+
+        std::vector<int64_t> axes1 = extractAxis(net, matchedNodesIds[mean1]);
+        if (axes.size() != axes1.size())
+            return false;
+        for (size_t i = 0; i < axes.size(); i++) {
+            if (((axes[i] + input_ndims) % input_ndims) != ((axes1[i] + input_ndims) % input_ndims)) {
+                return false;
+            }
+        }
+        axis = axes[0];

         epsilon = extractConstant(net, matchedNodesIds[add], 1).at<float>(0);
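Normalizing each axis with `(axis + input_ndims) % input_ndims` maps negative and non-negative spellings of the same dimension onto one canonical index, so the two ReduceMean nodes are compared by meaning rather than by literal value. A small worked check of that identity:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Canonicalize an ONNX axis index: on a 4-D tensor, -1 becomes 3, -3 becomes 1.
int64_t normalizeAxis(int64_t axis, int64_t ndims) {
    return (axis + ndims) % ndims;
}

int main() {
    const int64_t ndims = 4;
    // [1, 2, 3] and [-3, -2, -1] name the same trailing dimensions of a
    // 4-D tensor; after normalization they compare equal element-wise.
    const std::vector<int64_t> a{1, 2, 3};
    const std::vector<int64_t> b{-3, -2, -1};
    for (size_t i = 0; i < a.size(); i++)
        assert(normalizeAxis(a[i], ndims) == normalizeAxis(b[i], ndims));
    return 0;
}
```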
@@ -47,6 +47,10 @@ TEST_F(Test_Graph_Simplifier, LayerNormSubGraph) {
     test("layer_norm_expanded_with_initializers", "LayerNormalization");
 }

+TEST_F(Test_Graph_Simplifier, LayerNormNoFusionSubGraph) {
+    test("layer_norm_no_fusion", std::vector<std::string>{"NaryEltwise", "Reduce", "Sqrt"});
+}
+
 TEST_F(Test_Graph_Simplifier, ResizeSubgraph) {
     /* Test for 6 subgraphs:
         - GatherCastSubgraph
@@ -3024,6 +3024,10 @@ TEST_P(Test_ONNX_nets, VitTrack) {
     normAssert(ref_output3, outputs[2], "VitTrack output3");
 }

+TEST_P(Test_ONNX_layers, LayerNormNoFusion) {
+    testONNXModels("layer_norm_no_fusion");
+}
+
 INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());

 }} // namespace