mirror of
https://github.com/opencv/opencv.git
synced 2024-12-05 09:49:12 +08:00
0521a3a384
dnn: add attention layer #24476 Resolves #24609 Merge with: https://github.com/opencv/opencv_extra/pull/1128. Attention operator spec from onnxruntime: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention. TODO: - [x] benchmark (before this PR vs. with this PR vs. ORT). - [x] Layer fusion: Take care Slice with end=INT64_MAX. - [x] Layer fusion: match more potential attention (VIT) patterns. - [x] Single-head attention is supported. - [x] Test AttentionSubgraph fusion. - [x] Add acc tests for VIT_B_32 and VitTrack - [x] Add perf tests for VIT_B_32 and VitTrack ## Benchmarks Platform: Macbook Air M1. ### Attention Subgraph Input scale: [1, 197, 768]. | | mean (ms) | median (ms) | min (ms) | | ---------------------- | --------- | ----------- | -------- | | w/ Attention (this PR) | 3.75 | 3.68 | 3.22 | | w/o Attention | 9.06 | 9.01 | 8.24 | | ORT (python) | 4.32 | 2.63 | 2.50 | ### ViTs All data in millisecond (ms). | ViTs | With Attention | Without Attention | ORT | | -------- | -------------- | ----------------- | ------ | | vit_b_16 | 302.77 | 365.35 | 109.70 | | vit_b_32 | 89.92 | 116.22 | 30.36 | | vit_l_16 | 1593.32 | 1730.74 | 419.92 | | vit_l_32 | 468.11 | 577.41 | 134.12 | | VitTrack | 3.80 | 3.87 | 2.25 | ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
143 lines
5.2 KiB
C++
143 lines
5.2 KiB
C++
// This file is part of OpenCV project.
|
|
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
|
// of this distribution and at http://opencv.org/license.html.
|
|
|
|
#include "test_precomp.hpp"
|
|
|
|
namespace opencv_test { namespace {
|
|
|
|
class Test_Graph_Simplifier : public ::testing::Test {
|
|
public:
|
|
bool required;
|
|
|
|
Test_Graph_Simplifier() : required(true) {}
|
|
|
|
void test_conformance(const std::string &basename, const std::string &expected_layer) {
|
|
test(basename + std::string("/model"), std::vector<std::string>{expected_layer}, std::string("dnn/onnx/conformance/node/"));
|
|
}
|
|
|
|
void test(const std::string &basename, const std::string &expected_layer) {
|
|
test(basename, std::vector<std::string>{expected_layer});
|
|
}
|
|
|
|
void test(const std::string &basename, const std::vector<std::string> &expected_layers, const std::string &model_path_prefix = std::string("dnn/onnx/models/")) {
|
|
std::string model_path = findDataFile(model_path_prefix + basename + std::string(".onnx"), required);
|
|
auto net = readNet(model_path);
|
|
std::vector<std::string> layers;
|
|
net.getLayerTypes(layers);
|
|
|
|
// remove Const, Identity (output layer), __NetInputLayer__ (input layer)
|
|
layers.erase(std::remove_if(layers.begin(), layers.end(), [] (const std::string l) { return l == "Const" || l == "Identity" || l == "__NetInputLayer__"; }), layers.end());
|
|
|
|
EXPECT_EQ(layers, expected_layers);
|
|
}
|
|
};
|
|
|
|
TEST_F(Test_Graph_Simplifier, GeluSubGraph) {
    // "gelu" must import as exactly one Gelu layer.
    test("gelu", std::vector<std::string>{"Gelu"});

    // "bias_gelu" must import as Gelu plus a separate NaryEltwise layer
    // (presumably the bias addition is left unfused — confirm against the model).
    const std::vector<std::string> gelu_with_bias{"Gelu", "NaryEltwise"};
    test("bias_gelu", gelu_with_bias);
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, GeluApproximationSubGraph) {
    // "gelu_approximation" must import as exactly one GeluApproximation layer.
    const std::vector<std::string> expected{"GeluApproximation"};
    test("gelu_approximation", expected);
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, LayerNormSubGraph) {
    // Both expanded layer-norm models (plain and with initializers) must import
    // as a single LayerNormalization layer.
    for (const char *model : {"layer_norm_expanded", "layer_norm_expanded_with_initializers"})
        test(model, "LayerNormalization");
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, ResizeSubgraph) {
    /* Covers 6 subgraphs:
        - GatherCastSubgraph
        - MulCastSubgraph
        - UpsampleSubgraph
        - ResizeSubgraph1
        - ResizeSubgraph2
        - ResizeSubgraph3
    */
    // Expectation shared by the three torch 1.3 / 1.4 exports below.
    const std::vector<std::string> bn_conv_resize{"BatchNorm", "Convolution", "Resize"};

    test("upsample_unfused_torch1.2", std::vector<std::string>{"BatchNorm", "Resize"});
    test("resize_nearest_unfused_opset11_torch1.3", bn_conv_resize);
    test("resize_nearest_unfused_opset11_torch1.4", bn_conv_resize);
    test("upsample_unfused_opset9_torch1.4", bn_conv_resize);
    test("two_resizes_with_shared_subgraphs", std::vector<std::string>{"NaryEltwise", "Resize"});
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, SoftmaxSubgraph) {
    /* Covers 3 subgraphs:
        - SoftMaxSubgraph
        - SoftMaxSubgraph2 (conformance)
        - LogSoftMaxSubgraph (conformance)
    */
    test("softmax_unfused", "Softmax");

    // Expanded ONNX conformance cases, checked in this order. The log-softmax cases are
    // also expected to import as a "Softmax" layer.
    static const char *const conformance_cases[] = {
        "test_softmax_example_expanded",
        "test_softmax_axis_2_expanded",
        "test_softmax_default_axis_expanded",
        "test_softmax_axis_0_expanded",
        "test_softmax_axis_1_expanded",
        "test_softmax_large_number_expanded",
        "test_softmax_negative_axis_expanded",
        "test_logsoftmax_axis_2_expanded",
        "test_logsoftmax_example_1_expanded",
        "test_logsoftmax_negative_axis_expanded",
        "test_logsoftmax_axis_0_expanded",
        "test_logsoftmax_axis_1_expanded",
        "test_logsoftmax_large_number_expanded",
        "test_logsoftmax_default_axis_expanded",
    };
    for (const char *conformance_case : conformance_cases)
        test_conformance(conformance_case, "Softmax");
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, HardSwishSubgraph) {
    // The expanded HardSwish conformance case must import as a single HardSwish layer.
    const std::string conformance_case = "test_hardswish_expanded";
    test_conformance(conformance_case, "HardSwish");
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, CeluSubgraph) {
    // The expanded Celu conformance case must import as a single Celu layer.
    const std::string conformance_case = "test_celu_expanded";
    test_conformance(conformance_case, "Celu");
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, NormalizeSubgraph) {
    /* Covers 6 subgraphs:
        - NormalizeSubgraph1
        - NormalizeSubgraph2
        - NormalizeSubgraph2_2
        - NormalizeSubgraph3
        - NormalizeSubgraph4
        - NormalizeSubgraph5
    */
    // Every model below must import as a single Normalize layer.
    for (const char *model : {"reduceL2_subgraph_2", "reduceL2_subgraph", "normalize_fusion"})
        test(model, "Normalize");
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, BatchNormalizationSubgraph) {
    /* Covers 2 subgraphs:
        - BatchNormalizationSubgraph1
        - BatchNormalizationSubgraph2
    */
    // Both models must import as a single BatchNorm layer.
    for (const char *model : {"frozenBatchNorm2d", "batch_norm_subgraph"})
        test(model, "BatchNorm");
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, ExpandSubgraph) {
    // "expand_neg_batch" must import as a single Expand layer.
    const std::vector<std::string> expected{"Expand"};
    test("expand_neg_batch", expected);
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, MishSubgraph) {
    /* Covers 2 subgraphs:
        - SoftplusSubgraph
        - MishSubgraph
    */
    // Both models must import as a single Mish layer.
    for (const char *model : {"mish_no_softplus", "mish"})
        test(model, "Mish");
}
|
|
|
|
TEST_F(Test_Graph_Simplifier, AttentionSubgraph) {
    /* Covers 2 subgraphs:
        - AttentionSubgraph
        - AttentionSingleHeadSubgraph
    */
    // Both the multi-head and single-head models must import as a single Attention layer.
    for (const char *model : {"attention", "attention_single_head"})
        test(model, "Attention");
}
|
|
|
|
}}
|