Merge pull request #24476 from fengyuentau:attention_layer

dnn: add attention layer #24476

Resolves #24609

Merge with: https://github.com/opencv/opencv_extra/pull/1128.

Attention operator spec from onnxruntime: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention.
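For orientation, below is a minimal numpy sketch of what the fused `Attention` node computes in the common case this PR targets (merged QKV weight, no attention mask, equal Q/K/V hidden sizes). The helper and variable names are illustrative and not part of the patch:

```python
import numpy as np

def attention_ref(x, weight, bias, num_heads):
    """Reference for com.microsoft.Attention with a merged QKV weight.
    x: (B, S, W_in), weight: (W_in, 3*W), bias: (3*W,)."""
    B, S, _ = x.shape
    W = weight.shape[1] // 3
    H = W // num_heads
    qkv = x @ weight + bias                        # (B, S, 3*W)
    q, k, v = np.split(qkv, 3, axis=-1)            # each (B, S, W)
    def to_heads(t):                               # (B, S, W) -> (B, N, S, H)
        return t.reshape(B, S, num_heads, H).transpose(0, 2, 1, 3)
    q, k, v = to_heads(q), to_heads(k), to_heads(v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(H)    # (B, N, S, S)
    prob = np.exp(scores - scores.max(-1, keepdims=True))
    prob /= prob.sum(-1, keepdims=True)                  # softmax over last axis
    out = prob @ v                                       # (B, N, S, H)
    return out.transpose(0, 2, 1, 3).reshape(B, S, W)    # back to (B, S, W)
```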

TODO:
- [x] benchmark (before this PR vs. with this PR vs. ORT).
- [x] Layer fusion: take care of Slice with end=INT64_MAX.
- [x] Layer fusion: match more potential attention (ViT) patterns.
    - [x] Single-head attention is supported.
- [x] Test AttentionSubgraph fusion.
- [x] Add accuracy tests for VIT_B_32 and VitTrack.
- [x] Add performance tests for VIT_B_32 and VitTrack.

## Benchmarks

Platform: MacBook Air M1.

### Attention Subgraph

Input shape: [1, 197, 768].

|                        | mean (ms) | median (ms) | min (ms) |
| ---------------------- | --------- | ----------- | -------- |
| w/ Attention (this PR) | 3.75      | 3.68        | 3.22     |
| w/o Attention          | 9.06      | 9.01        | 8.24     |
| ORT (python)           | 4.32      | 2.63        | 2.50     |
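
The "ORT (python)" row can be reproduced roughly with a timing loop like the sketch below, assuming onnxruntime is installed and the attention subgraph has been exported to `attention.onnx` (the file name is a placeholder):

```python
import time
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("attention.onnx", providers=["CPUExecutionProvider"])
inp_name = sess.get_inputs()[0].name
x = np.random.rand(1, 197, 768).astype(np.float32)

times = []
for _ in range(100):
    t0 = time.perf_counter()
    sess.run(None, {inp_name: x})
    times.append((time.perf_counter() - t0) * 1000)  # milliseconds
print("mean {:.2f}, median {:.2f}, min {:.2f} (ms)".format(
    np.mean(times), np.median(times), np.min(times)))
```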

### ViTs

All timings are in milliseconds (ms).

| ViTs     | With Attention | Without Attention | ORT    |
| -------- | -------------- | ----------------- | ------ |
| vit_b_16 | 302.77         | 365.35            | 109.70 |
| vit_b_32 | 89.92          | 116.22            | 30.36  |
| vit_l_16 | 1593.32        | 1730.74           | 419.92 |
| vit_l_32 | 468.11         | 577.41            | 134.12 |
| VitTrack | 3.80           | 3.87              | 2.25   |
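
The OpenCV-side numbers can be reproduced roughly with a script like the sketch below (assuming `vit_b_32.onnx` from the opencv_extra testdata is available locally; the path and iteration count are illustrative):

```python
import cv2
import numpy as np

# vit_b_32.onnx is assumed to be available locally (e.g. from opencv_extra testdata).
net = cv2.dnn.readNet("vit_b_32.onnx")
net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)

image = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)  # dummy input image
blob = cv2.dnn.blobFromImage(image, 1.0 / 255.0, (224, 224))
net.setInput(blob)
net.forward()  # warm-up run

tm = cv2.TickMeter()
for _ in range(10):
    tm.start()
    net.forward()
    tm.stop()
print("avg forward time: {:.2f} ms".format(tm.getAvgTimeMilli()))
```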

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
Committed by Yuantao Feng on 2023-12-21 00:35:07 +08:00 via GitHub.
parent e64c5dc4c6, commit 0521a3a384
13 changed files with 891 additions and 66 deletions


@ -1178,6 +1178,11 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<InstanceNormLayer> create(const LayerParams &params);
};
class CV_EXPORTS AttentionLayer : public Layer {
public:
static Ptr<AttentionLayer> create(const LayerParams &params);
};
//! @}
//! @}
CV__DNN_INLINE_NS_END


@ -739,6 +739,62 @@ PERF_TEST_P_(Layer_InstanceNorm, InstanceNorm)
test_layer({N, C, H, W});
}
struct Layer_Attention : public TestBaseWithParam<tuple<Backend, Target>> {
void test_layer(const std::vector<int> x_shape, const std::vector<int> qkv_hidden_sizes, const int num_heads) {
int backendId = get<0>(GetParam());
int targetId = get<1>(GetParam());
auto qk_hidden_size = qkv_hidden_sizes[0];
auto v_hidden_size = qkv_hidden_sizes[2];
auto input_hidden_size = x_shape[2];
auto hidden_size = qk_hidden_size + qk_hidden_size + v_hidden_size;
Mat x(x_shape, CV_32F);
Mat weight(std::vector<int>{input_hidden_size, hidden_size}, CV_32F);
Mat bias(std::vector<int>{hidden_size}, CV_32F);
randu(x, 0.f, 1.f);
randu(weight, 0.f, 1.f);
randu(bias, 0.f, 1.f);
LayerParams lp;
lp.type = "Attention";
lp.name = "testLayer";
lp.set("num_heads", num_heads);
lp.set("qkv_hidden_sizes", DictValue::arrayInt(qkv_hidden_sizes.data(), qkv_hidden_sizes.size()));
Net net;
int id = net.addLayerToPrev(lp.name, lp.type, lp);
net.connect(0, 0, id, 0);
net.connect(0, 1, id, 1);
net.connect(0, 2, id, 2);
{
std::vector<std::string> input_names{"x", "weight", "bias"};
net.setInputsNames(input_names);
net.setInput(x, input_names[0]);
net.setInput(weight, input_names[1]);
net.setInput(bias, input_names[2]);
net.setPreferableBackend(backendId);
net.setPreferableTarget(targetId);
Mat out = net.forward();
}
TEST_CYCLE()
{
Mat out = net.forward();
}
SANITY_CHECK_NOTHING();
}
};
PERF_TEST_P_(Layer_Attention, VisionTransformer) {
test_layer({1, 197, 768}, {768, 768, 768}, 12);
}
INSTANTIATE_TEST_CASE_P(/**/, Layer_Slice, dnnBackendsAndTargets(false, false));
INSTANTIATE_TEST_CASE_P(/**/, Layer_NaryEltwise, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
#ifdef HAVE_CUDA
@ -750,6 +806,7 @@ INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNorm, testing::Values(std::make_tuple(D
INSTANTIATE_TEST_CASE_P(/**/, Layer_LayerNormExpanded, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
INSTANTIATE_TEST_CASE_P(/**/, Layer_GatherElements, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
INSTANTIATE_TEST_CASE_P(/**/, Layer_InstanceNorm, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
INSTANTIATE_TEST_CASE_P(/**/, Layer_Attention, testing::Values(std::make_tuple(DNN_BACKEND_OPENCV, DNN_TARGET_CPU)));
typedef TestBaseWithParam<tuple<Vec4i, int, bool, tuple<Backend, Target> > > Layer_FullyConnected;


@ -93,7 +93,6 @@ public:
}
};
PERF_TEST_P_(DNNTestNetwork, AlexNet)
{
processNet("dnn/bvlc_alexnet.caffemodel", "dnn/bvlc_alexnet.prototxt",
@ -391,17 +390,16 @@ PERF_TEST_P_(DNNTestNetwork, CRNN) {
processNet("", "dnn/text_recognition_CRNN_EN_2021sep.onnx", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, ViTTrack) {
PERF_TEST_P_(DNNTestNetwork, VitTrack) {
Mat inp1(cv::Size(128, 128), CV_32FC3);
Mat inp2(cv::Size(256, 256), CV_32FC3);
randu(inp1, 0.0f, 1.0f);
randu(inp2, 0.0f, 1.0f);
inp1 = blobFromImage(inp1, 1.0, Size(), Scalar(), false);
inp2 = blobFromImage(inp2, 1.0, Size(), Scalar(), false);
processNet("", "dnn/onnx/models/vitTracker.onnx", "", {std::make_tuple(inp1, "template"), std::make_tuple(inp2, "search")});
processNet("", "dnn/onnx/models/object_tracking_vittrack_2023sep.onnx", "", {std::make_tuple(inp1, "template"), std::make_tuple(inp2, "search")});
}
PERF_TEST_P_(DNNTestNetwork, EfficientDet_int8)
{
if (target != DNN_TARGET_CPU || (backend != DNN_BACKEND_OPENCV &&
@ -413,6 +411,10 @@ PERF_TEST_P_(DNNTestNetwork, EfficientDet_int8)
processNet("", "dnn/tflite/coco_efficientdet_lite0_v1_1.0_quant_2021_09_06.tflite", "", inp);
}
PERF_TEST_P_(DNNTestNetwork, VIT_B_32) {
processNet("", "dnn/onnx/models/vit_b_32.onnx", "", cv::Size(224, 224));
}
INSTANTIATE_TEST_CASE_P(/*nothing*/, DNNTestNetwork, dnnBackendsAndTargets());
} // namespace


@ -162,6 +162,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(LayerNormalization, LayerNormLayer);
CV_DNN_REGISTER_LAYER_CLASS(Expand, ExpandLayer);
CV_DNN_REGISTER_LAYER_CLASS(InstanceNormalization, InstanceNormLayer);
CV_DNN_REGISTER_LAYER_CLASS(Attention, AttentionLayer);
CV_DNN_REGISTER_LAYER_CLASS(Crop, CropLayer);
CV_DNN_REGISTER_LAYER_CLASS(Eltwise, EltwiseLayer);


@ -0,0 +1,272 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#include "cpu_kernels/fast_gemm.hpp"
#include "cpu_kernels/softmax.hpp"
#include <opencv2/dnn/shape_utils.hpp>
namespace cv { namespace dnn {
static void packWeight(size_t num_heads, size_t head_size, size_t input_hidden_size,
const float *weight_data, size_t hidden_size, std::vector<float> &packed_weight, const FastGemmOpt &opt) {
// num_heads * pack(head_size, input_hidden_size)
size_t pack_size = fastGemmPackBSize(head_size, input_hidden_size, opt);
size_t packed_weight_size = num_heads * pack_size;
packed_weight.resize(packed_weight_size, 0.f);
auto *packed_weight_data = packed_weight.data();
for (size_t i = 0; i < num_heads; i++) {
fastGemmPackB(false, head_size, input_hidden_size, weight_data, hidden_size, packed_weight_data, opt);
packed_weight_data += pack_size;
weight_data += head_size;
}
}
// Operator spec: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention
class AttentionLayerImpl CV_FINAL : public AttentionLayer {
public:
AttentionLayerImpl(const LayerParams &params) {
setParamsFrom(params);
CV_CheckTrue(params.has("num_heads"), "DNN/Attention: num_heads is required but missing");
num_heads = params.get<int>("num_heads"); // required, no default value
CV_CheckTrue(params.has("qkv_hidden_sizes"), "DNN/Attention: qkv_hidden_sizes is required but missing");
auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes");
CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "DNN/Attention: qkv_hidden_sizes must have exactly three elements");
qkv_hidden_sizes.clear();
qkv_hidden_sizes.resize(3);
qkv_hidden_sizes[0] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(0));
qkv_hidden_sizes[1] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(1));
/* v_hidden_size needs to be initialized in finalize in case v_slice_end=INT_MAX */
qkv_head_sizes.clear();
qkv_head_sizes.resize(3);
qkv_head_sizes[0] = static_cast<size_t>(qkv_hidden_sizes[0] / num_heads);
qkv_head_sizes[1] = static_cast<size_t>(qkv_hidden_sizes[1] / num_heads);
scale = 1.f / params.get<float>("scale", sqrt(qkv_head_sizes[0]));
output_ndims = params.get<int>("output_ndims", 3);
is_prepacked = false;
}
virtual bool supportBackend(int backendId) CV_OVERRIDE {
return backendId == DNN_BACKEND_OPENCV;
}
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE {
CV_CheckEQ(inputs.size(), static_cast<size_t>(3), "DNN/Attention: three inputs are required");
const auto &input_shape = inputs[0];
const auto &weight_shape = inputs[1];
const auto &bias_shape = inputs[2];
CV_CheckEQ(input_shape.size(), static_cast<size_t>(3), "DNN/Attention: invalid input dimension");
CV_CheckEQ(weight_shape.size(), static_cast<size_t>(2), "DNN/Attention: invalid weight dimension");
CV_CheckEQ(input_shape[2], weight_shape[0], "DNN/Attention: invalid input shape");
CV_CheckEQ(weight_shape[1], bias_shape[0], "DNN/Attention: invalid weight or bias shape");
if (output_ndims == 3) {
outputs.assign(1, inputs[0]);
} else if (output_ndims == 2) {
int batch = input_shape[0], seq_len = input_shape[1], input_hidden_size = input_shape[2];
MatShape output_shape{batch * seq_len, input_hidden_size};
outputs.assign(1, output_shape);
} else {
CV_Error(Error::StsBadArg, format("DNN/Attention: invalid output dimension %zu, valid value is 2 or 3", output_ndims));
}
return false;
}
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
opt.init();
std::vector<Mat> inputs;
inputs_arr.getMatVector(inputs);
const auto input_shape = shape(inputs[0]);
batch_size = static_cast<size_t>(input_shape[0]);
seq_len = static_cast<size_t>(input_shape[1]);
input_hidden_size = static_cast<size_t>(input_shape[2]);
const auto weight_shape = shape(inputs[1]);
hidden_size = weight_shape[1];
qkv_hidden_sizes[2] = hidden_size - qkv_hidden_sizes[0] - qkv_hidden_sizes[1];
qkv_head_sizes[2] = static_cast<size_t>(qkv_hidden_sizes[2] / num_heads);
}
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
if (inputs_arr.depth() == CV_16S)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
// prepack weights
if (!is_prepacked) {
// prepack
const auto &weight = inputs[1];
const auto *weight_data = weight.ptr<const float>();
packWeight(num_heads, qkv_head_sizes[0], input_hidden_size, weight_data, hidden_size, packed_weight_q, opt);
packWeight(num_heads, qkv_head_sizes[1], input_hidden_size, weight_data + qkv_hidden_sizes[0], hidden_size, packed_weight_k, opt);
packWeight(num_heads, qkv_head_sizes[2], input_hidden_size, weight_data + qkv_hidden_sizes[0] + qkv_hidden_sizes[1], hidden_size, packed_weight_v, opt);
is_prepacked = true;
}
float *packed_weights[3] = {packed_weight_q.data(), packed_weight_k.data(), packed_weight_v.data()};
size_t packed_weights_size[3] = {packed_weight_q.size() / num_heads, packed_weight_k.size() / num_heads, packed_weight_v.size() / num_heads};
Mat gemm_buffer = Mat::zeros(1, int(batch_size * seq_len * hidden_size), CV_32F);
auto *Q = gemm_buffer.ptr<float>();
auto *K = Q + batch_size * seq_len * qkv_hidden_sizes[0];
auto *V = K + batch_size * seq_len * qkv_hidden_sizes[1];
float *QKV[3] = {Q, K, V}; // Q, K, V: [B, N, S, H]
{
const auto &input = inputs[0];
const auto &bias = inputs[2];
const auto *input_data = input.ptr<const float>();
const auto *bias_data = bias.ptr<const float>();
opt.multi_thread = false;
auto fn = [&](const Range &r) {
for (int i = r.start; i < r.end; i++) {
const int batch_index = static_cast<int>((i / 3) / num_heads);
const int head_index = static_cast<int>((i / 3) % num_heads);
const int qkv_index = static_cast<int>(i % 3);
auto *dst = QKV[qkv_index];
size_t head_size = qkv_head_sizes[qkv_index];
int input_offset = batch_index * seq_len * input_hidden_size;
int bias_offset = qkv_index * qkv_hidden_sizes[0] + head_index * head_size;
int dst_offset = (batch_index * num_heads + head_index) * (seq_len * head_size);
// broadcast bias ([NH] -> [BN, SH]) and make copy to dst
const auto *bias_data_src = bias_data + bias_offset;
auto *dst_data = dst + dst_offset;
for (size_t seq_len_idx = 0; seq_len_idx < seq_len; seq_len_idx++) {
std::memcpy(dst_data, bias_data_src, head_size * sizeof(float));
dst_data += head_size;
}
auto *packed_weight = packed_weights[qkv_index] + packed_weights_size[qkv_index] * head_index;
// single-thread gemm kernel
fastGemm(false, seq_len, head_size, input_hidden_size,
1.f, input_data + input_offset, input_hidden_size,
packed_weight, 1.f, dst + dst_offset, head_size, opt);
}
};
size_t loops = 3 * batch_size * num_heads;
double nstripes = loops * seq_len * qkv_head_sizes[0] * input_hidden_size * (1 / 1024.0);
parallel_for_(Range(0, loops), fn, nstripes);
}
// Compute softmax(scale * matmul(Q, K))
std::vector<int> attention_prob_shape{int(batch_size * num_heads), int(seq_len), int(seq_len)};
Mat attention_prob = Mat::zeros(attention_prob_shape.size(), attention_prob_shape.data(), CV_32F);
{
auto *output = attention_prob.ptr<float>();
auto loops = batch_size * num_heads;
auto seq_len_square = seq_len * seq_len;
auto qk_head_size = qkv_head_sizes[0];
auto qk_inner_size = seq_len * qk_head_size;
// Compute scale * matmul(Q, K)
opt.multi_thread = false;
parallel_for_(Range(0, loops), [&] (const Range r) {
for (int i = r.start; i < r.end; i++) {
const int output_offset = i * seq_len_square;
const auto *q = Q + qk_inner_size * i, *k = K + qk_inner_size * i;
fastGemm(false, true, seq_len, qk_head_size, seq_len, qk_head_size,
scale, q, qk_head_size, 1,
k, qk_head_size, 1, 0.f,
output + output_offset, seq_len, opt);
}
}, loops * seq_len * qk_head_size * seq_len * (1 / 1024.0));
// Compute softmax
softmax(attention_prob, attention_prob, attention_prob_shape.size() - 1);
}
// Compute np.matmul(attention_prob, V)
Mat output_buffer = Mat::zeros(1, int(batch_size * num_heads * seq_len * qkv_head_sizes[2]), CV_32F);
{
auto *output = outputs[0].ptr<float>();
auto *output_buff = output_buffer.ptr<float>();
const auto *prob = attention_prob.ptr<const float>();
auto loops = batch_size * num_heads;
auto prob_inner_size = seq_len * seq_len;
auto v_head_size = qkv_head_sizes[2];
auto v_inner_size = seq_len * v_head_size;
opt.multi_thread = false;
parallel_for_(Range(0, loops), [&] (const Range &r) {
for (int i = r.start; i < r.end; i++) {
const int output_offset = i * v_inner_size;
const auto *p = prob + i * prob_inner_size, *v = V + i * v_inner_size;
fastGemm(false, false, seq_len, seq_len, seq_len, v_head_size,
1.f, p, seq_len, 1,
v, v_head_size, 1, 0.f,
output_buff + output_offset, v_head_size, opt);
// transpose on the fly
const int batch_index = static_cast<int>(i / num_heads);
const int head_index = static_cast<int>(i % num_heads);
auto *src = output_buff + output_offset;
auto *dst = output + (batch_index * seq_len * num_heads + head_index) * v_head_size;
for (int j = 0; j < seq_len; j++) {
std::memcpy(dst, src, v_head_size * sizeof(float));
src += v_head_size;
dst += qkv_hidden_sizes[2];
}
}
}, loops * seq_len * seq_len * v_head_size * (1 / 1024.0));
}
}
private:
size_t num_heads;
std::vector<size_t> qkv_hidden_sizes; // order: {qk_hidden_size, qk_hidden_size, v_hidden_size}
float scale;
size_t output_ndims;
std::vector<size_t> qkv_head_sizes; // order: {qk_head_size, qk_head_size, v_head_size}
size_t batch_size;
size_t seq_len;
size_t input_hidden_size;
size_t hidden_size;
bool is_prepacked;
std::vector<float> packed_weight_q;
std::vector<float> packed_weight_k;
std::vector<float> packed_weight_v;
FastGemmOpt opt;
};
Ptr<AttentionLayer> AttentionLayer::create(const LayerParams &params) {
return makePtr<AttentionLayerImpl>(params);
}
}} // cv::dnn


@ -20,6 +20,32 @@
namespace cv { namespace dnn {
size_t fastGemmPackBSize(size_t N, size_t K, const FastGemmOpt &opt) {
#if CV_TRY_NEON
if (opt.use_neon) {
return static_cast<size_t>(opt_NEON::fastGemmPackBSize(N, K));
} else
#endif
#if CV_TRY_AVX2
if (opt.use_avx2) {
return static_cast<size_t>(opt_AVX2::fastGemmPackBSize(N, K));
} else
#endif
#if CV_TRY_AVX
if (opt.use_avx) {
return static_cast<size_t>(opt_AVX::fastGemmPackBSize(N, K));
} else
#endif
#if CV_TRY_LASX
if (opt.use_lasx) {
return static_cast<size_t>(opt_LASX::fastGemmPackBSize(N, K));
} else
#endif
{
return static_cast<size_t>(cpu_baseline::fastGemmPackBSize(N, K));
}
}
void fastGemmPackB(const Mat &B, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt) {
CV_CheckTypeEQ(B.type(), CV_32F, "fastGemmPackB: only float32 is supported for now");
@ -94,10 +120,45 @@ void fastGemmPackB(const Mat &B, std::vector<float> &packed_B, bool trans, FastG
}
}
void fastGemmPackB(bool trans, size_t N, size_t K, const float *B, size_t ldb, float *packed_B, const FastGemmOpt &opt) {
size_t ldb0 = ldb, ldb1 = 1;
if (trans) {
std::swap(K, N);
std::swap(ldb0, ldb1);
}
const auto &b = (const char *)B;
auto *packed_b = (char *)packed_B;
#if CV_TRY_NEON
if (opt.use_neon) {
opt_NEON::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
} else
#endif
#if CV_TRY_AVX2
if (opt.use_avx2) {
opt_AVX2::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
} else
#endif
#if CV_TRY_AVX
if (opt.use_avx) {
opt_AVX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
} else
#endif
#if CV_TRY_LASX
if (opt.use_lasx) {
opt_LASX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
} else
#endif
{
cpu_baseline::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, sizeof(float));
}
}
static void fast_gemm_thin(float alpha, float beta, int M, int N, int K,
const char *a_, int lda0, int lda1,
const char *b_, int ldb,
char *c_, int ldc) {
char *c_, int ldc, bool multi_thread) {
const float* a = (const float*)a_;
auto fn = [&](const Range &r) {
@ -116,16 +177,24 @@ static void fast_gemm_thin(float alpha, float beta, int M, int N, int K,
}
};
int total = M; // outer loops
int cost_per_thread = static_cast<int>(K * N); // inner loops
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
if (multi_thread) {
int total = M; // outer loops
int cost_per_thread = static_cast<int>(K * N); // inner loops
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
} else {
fn(Range(0, M));
}
}
void fastGemm(bool trans_a, int M, int N, int K,
float alpha, const float *A, int lda,
const float *packed_B, float beta,
float *C, int ldc, FastGemmOpt &opt) {
const char *a = (const char *)A;
const char *packed_b = (const char *)packed_B;
char *c = (char *)C;
int lda0 = lda, lda1 = 1;
if (trans_a) {
std::swap(lda0, lda1);
@ -133,26 +202,26 @@ void fastGemm(bool trans_a, int M, int N, int K,
#if CV_TRY_NEON
if (opt.use_neon) {
opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
opt_NEON::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
#if CV_TRY_AVX2
if (opt.use_avx2) {
opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
opt_AVX2::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
#if CV_TRY_AVX
if (opt.use_avx) {
opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
opt_AVX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
#if CV_TRY_LASX
if (opt.use_lasx) {
opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
opt_LASX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
{
cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float));
cpu_baseline::fastGemmKernel(M, N, K, alpha, a, lda0, lda1, packed_b, beta, c, ldc, sizeof(float), opt.multi_thread);
}
}
@ -175,36 +244,41 @@ void fastGemm(bool trans_a, bool trans_b, int ma, int na, int mb, int nb,
}
if (!trans_b && ldb1 == 1 && (M <= 4 || (uint64_t)M * N * K <= 10000)) {
return fast_gemm_thin(alpha, beta, M, N, K, a, lda0, lda1, b, ldb0, c, ldc);
return fast_gemm_thin(alpha, beta, M, N, K, a, lda0, lda1, b, ldb0, c, ldc, opt.multi_thread);
}
#if CV_TRY_NEON
if (opt.use_neon) {
opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
(const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
opt_NEON::fastGemmKernel(M, N, K, alpha, a, lda0, lda1,
b, ldb0, ldb1, beta,
c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
#if CV_TRY_AVX2
if (opt.use_avx2) {
opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
(const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
opt_AVX2::fastGemmKernel(M, N, K, alpha, a, lda0, lda1,
b, ldb0, ldb1, beta,
c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
#if CV_TRY_AVX
if (opt.use_avx) {
opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
(const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
opt_AVX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1,
b, ldb0, ldb1, beta,
c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
#if CV_TRY_LASX
if (opt.use_lasx) {
opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
(const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
opt_LASX::fastGemmKernel(M, N, K, alpha, a, lda0, lda1,
b, ldb0, ldb1, beta,
c, ldc, sizeof(float), opt.multi_thread);
} else
#endif
{
cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1,
(const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float));
cpu_baseline::fastGemmKernel(M, N, K, alpha, a, lda0, lda1,
b, ldb0, ldb1, beta,
c, ldc, sizeof(float), opt.multi_thread);
}
}


@ -22,12 +22,14 @@ struct FastGemmOpt {
bool use_avx2;
bool use_neon;
bool use_lasx;
bool multi_thread;
FastGemmOpt() {
use_avx = false;
use_avx2 = false;
use_neon = false;
use_lasx = false;
multi_thread = false;
}
void init() {
@ -35,6 +37,7 @@ struct FastGemmOpt {
use_avx2 = checkHardwareSupport(CPU_AVX2);
use_neon = checkHardwareSupport(CPU_NEON);
use_lasx = checkHardwareSupport(CPU_LASX);
multi_thread = true;
}
bool all() {
@ -148,7 +151,10 @@ struct MatMulHelper {
}
};
size_t fastGemmPackBSize(size_t N, size_t K, const FastGemmOpt &opt);
void fastGemmPackB(const Mat &m, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt);
void fastGemmPackB(bool trans, size_t N, size_t K, const float *B, size_t ldb, float *packed_B, const FastGemmOpt &opt);
void fastGemm(bool trans_a, int M, int N, int K,
float alpha, const float *A, int lda,


@ -83,10 +83,10 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0,
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1,
float beta, char *C, int ldc, int esz);
float beta, char *C, int ldc, int esz, bool multi_thread);
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz);
const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread);
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
@ -179,7 +179,7 @@ static void fast_gemm_macro_kernel(int m, int n, int k,
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1,
float beta, char *C, int ldc, int esz) {
float beta, char *C, int ldc, int esz, bool multi_thread) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
@ -236,15 +236,18 @@ void fastGemmKernel(int M, int N, int K,
}
};
int total = total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
if (multi_thread) {
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total_tiles), fn, nstripes);
} else {
fn(Range(0, total_tiles));
}
}
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz) {
const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
@ -301,10 +304,13 @@ void fastGemmKernel(int M, int N, int K,
}
};
int total = total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
if (multi_thread) {
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total_tiles), fn, nstripes);
} else {
fn(Range(0, total_tiles));
}
}
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,


@ -122,10 +122,10 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0,
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1,
float beta, char *C, int ldc, int esz);
float beta, char *C, int ldc, int esz, bool multi_thread);
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz);
const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread);
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
@ -568,7 +568,7 @@ void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0,
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1,
float beta, char *C, int ldc, int esz) {
float beta, char *C, int ldc, int esz, bool multi_thread) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
@ -646,15 +646,19 @@ void fastGemmKernel(int M, int N, int K,
}
};
int total = total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
if (multi_thread) {
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total_tiles), fn, nstripes);
} else {
fn(Range(0, total_tiles));
}
}
void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz) {
const char *packed_B, float beta, char *C, int ldc, int esz, bool multi_thread) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
@ -722,10 +726,13 @@ void fastGemmKernel(int M, int N, int K,
}
};
int total = total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
if (multi_thread) {
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total_tiles * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total_tiles), fn, nstripes);
} else {
fn(Range(0, total_tiles));
}
}
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,


@ -13,6 +13,7 @@
#include <opencv2/core/utils/logger.hpp>
#include <queue>
#include <limits>
namespace cv { namespace dnn {
CV__DNN_INLINE_NS_BEGIN
@ -181,6 +182,17 @@ static Mat extractConstant(const Ptr<ImportGraphWrapper>& net, int node_id, int
}
}
static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id) {
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1) {
return onnx_net->getNameOfInitializer(initializer_id);
} else {
const auto node = net->getNode(node_id);
return node->getInputName(input_id);
}
}
/* Slice operator has two optional inputs "axes" and "steps". Some models may be set to have
Slice with optional inputs of default values, some of them don't. This Subgraph adjusts
all optional inputs of Slice up to 5.
@ -212,12 +224,308 @@ class AdjustSliceAllOptionalInputsSubgraph : public Subgraph {
node->add_input("");
}
}
private:
private:
int slice_id;
size_t num_inputs_;
};
/* The fusion for the multi-head attention from vision transformer.
Abbreviations:
B - batch_size, symbolic;
S - sequence_length, symbolic;
W - hidden_size, W = N * H;
N - num_heads;
H - head_size;
Graph before fusion:
[Input](BxSxW)
|
LayerNorm
|
Transpose(perm=[1, 0, 2])
|
| (SxBxW)
|
Matmul[Weight(Wx3W)]
|
Add[Bias(3W)]
/ | \
q_Slice k_Slice v_Slice (output(SxBxW))
| | |
q_Reshape k_Reshape v_Reshape (output(Sx(BxN)xH), could be optional if N=1)
| | |
q_Transpose k_Transpose v_Transpose
(1,0,2) (1,2,0) (perm=1,0,2)
|((BxN)xSxH) |((BxN)xHxS) |
q_Div / /
\ / /
qk_MatMul /
| /
qk_Softmax /
| ((BxN)xSxS) / ((BxN)xSxH)
\ /
qkv_MatMul (output((BxN)xSxH))
|
Transpose(perm=1,2,0)
|
Reshape (output(SxH))
|
MatMul
|
Add
|
[Output](BxSxW)
Attributes:
num_heads - number of attention heads
qkv_hidden_sizes - hidden size of qkv respectively, [qk_hidden_size, qk_hidden_size, v_hidden_size],
assume qk_hidden_size = v_hidden_size for now. TODO: support qk_hidden_size != v_hidden_size
scale - scale factor of q, defaults to sqrt(1/num_heads)
Inputs:
weight - merged Q, K, V weights of shape [input_hidden_size, qk_hidden_size + qk_hidden_size + v_hidden_size]
bias - bias of shape [qk_hidden_size + qk_hidden_size + v_hidden_size]
Graph after fusion:
[Input](BxSxW)
|
LayerNorm
|
Transpose
|
Attention[weight, bias]
|
MatMul
|
Add
|
[Output](BxSxW)
For more details, see https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention.
*/
class AttentionSubGraph : public Subgraph {
public:
AttentionSubGraph() {
int input = addNodeToMatch("");
int transpose = addNodeToMatch("Transpose", input); // transpose does not make any difference to the accuracy in this subgraph
att_matmul = addNodeToMatch("MatMul", transpose, addNodeToMatch(""));
att_add = addNodeToMatch("Add", addNodeToMatch(""), att_matmul);
// v_path
slice_v = addNodeToMatch("Slice", std::vector<int>{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")});
int reshape_v = addNodeToMatch("Reshape", slice_v, addNodeToMatch(""));
int transpose_v = addNodeToMatch("Transpose", reshape_v);
// q_path
slice_q = addNodeToMatch("Slice", std::vector<int>{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")});
reshape_q = addNodeToMatch("Reshape", slice_q, addNodeToMatch(""));
int transpose_q = addNodeToMatch("Transpose", reshape_q);
div_q = addNodeToMatch("Div", transpose_q, addNodeToMatch(""));
// k_path
slice_k = addNodeToMatch("Slice", std::vector<int>{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")});
int reshape_k = addNodeToMatch("Reshape", slice_k, addNodeToMatch(""));
int transpose_k = addNodeToMatch("Transpose", reshape_k);
// qk
int matmul_qk = addNodeToMatch("MatMul", div_q, transpose_k);
int softmax_qk = addNodeToMatch("Softmax", matmul_qk);
// qkv
int matmul_qkv = addNodeToMatch("MatMul", softmax_qk, transpose_v);
int transpose_qkv = addNodeToMatch("Transpose", matmul_qkv);
last_reshape = addNodeToMatch("Reshape", transpose_qkv, addNodeToMatch(""));
setFusedNode("Attention", input);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
// get attrs - qkv_hidden_sizes
qkv_hidden_sizes.clear();
auto fill_qkv_hidden_sizes = [&] (const int slice_node_id) {
int slice_start = extractConstant(net, matchedNodesIds[slice_node_id], 1).at<int>(0);
int slice_end = extractConstant(net, matchedNodesIds[slice_node_id], 2).at<int>(0);
if (slice_end == std::numeric_limits<int>::max()) {
qkv_hidden_sizes.push_back(0); // workaround for Slice with end=INT_MAX
} else {
int64_t hidden_size = static_cast<int64_t>(slice_end - slice_start);
qkv_hidden_sizes.push_back(hidden_size);
}
};
fill_qkv_hidden_sizes(slice_q);
fill_qkv_hidden_sizes(slice_k);
fill_qkv_hidden_sizes(slice_v); // TODO: take care of INT64_MAX
CV_CheckEQ(qkv_hidden_sizes.size(), static_cast<size_t>(3), "ONNXSimplifier/Attention: invalid qkv hidden sizes");
CV_CheckEQ(int(qkv_hidden_sizes[0]), int(qkv_hidden_sizes[1]), "ONNXSimplifier/Attention: invalid qkv hidden sizes, q_hidden_size == v_hidden_size is required");
// get attrs - num_heads, scale
num_heads = extractConstant(net, matchedNodesIds[reshape_q], 1).at<int>(1);
scale = extractConstant(net, matchedNodesIds[div_q], 1).at<float>(0);
output_ndims = extractConstant(net, matchedNodesIds[last_reshape], 1).size[0];
// get names
weight_name = getInputName(net, matchedNodesIds[att_matmul], 1);
bias_name = getInputName(net, matchedNodesIds[att_add], 0);
return true;
}
return false;
}
virtual void finalize(const Ptr<ImportGraphWrapper>& net,
const Ptr<ImportNodeWrapper>& fusedNode,
std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE {
// add attrs
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::AttributeProto* attr_num_heads = node->add_attribute();
attr_num_heads->set_name("num_heads");
attr_num_heads->set_i(num_heads);
opencv_onnx::AttributeProto* attr_qkv_hidden_sizes = node->add_attribute();
attr_qkv_hidden_sizes->set_name("qkv_hidden_sizes");
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[0]);
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[1]);
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[2]);
opencv_onnx::AttributeProto* attr_scale = node->add_attribute();
attr_scale->set_name("scale");
attr_scale->set_f(scale);
// add customized attrs
opencv_onnx::AttributeProto* attr_output_ndims = node->add_attribute();
attr_output_ndims->set_name("output_ndims");
attr_output_ndims->set_i(output_ndims);
// add inputs
node->add_input(weight_name);
node->add_input(bias_name);
}
private:
int att_matmul, att_add;
int slice_q, slice_k, slice_v;
int reshape_q, div_q, last_reshape;
std::vector<int64_t> qkv_hidden_sizes; // order: [qk_hidden_size, qk_hidden_size, v_hidden_size]
int64_t num_heads;
float scale;
int64_t output_ndims;
std::string weight_name;
std::string bias_name;
};
/* Attention subgraph with single head.
No Reshape operator is appended after each Slice operator.
*/
class AttentionSingleHeadSubGraph : public Subgraph {
public:
AttentionSingleHeadSubGraph() {
int input = addNodeToMatch("");
int transpose = addNodeToMatch("Transpose", input); // transpose does not make any difference to the accuracy in this subgraph
att_matmul = addNodeToMatch("MatMul", transpose, addNodeToMatch(""));
att_add = addNodeToMatch("Add", addNodeToMatch(""), att_matmul);
// v_path
slice_v = addNodeToMatch("Slice", std::vector<int>{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")});
int transpose_v = addNodeToMatch("Transpose", slice_v);
// q_path
slice_q = addNodeToMatch("Slice", std::vector<int>{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")});
int transpose_q = addNodeToMatch("Transpose", slice_q);
div_q = addNodeToMatch("Div", transpose_q, addNodeToMatch(""));
// k_path
slice_k = addNodeToMatch("Slice", std::vector<int>{att_add, addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch(""), addNodeToMatch("")});
int transpose_k = addNodeToMatch("Transpose", slice_k);
// qk
int matmul_qk = addNodeToMatch("MatMul", div_q, transpose_k);
int softmax_qk = addNodeToMatch("Softmax", matmul_qk);
// qkv
int matmul_qkv = addNodeToMatch("MatMul", softmax_qk, transpose_v);
int transpose_qkv = addNodeToMatch("Transpose", matmul_qkv);
last_reshape = addNodeToMatch("Reshape", transpose_qkv, addNodeToMatch(""));
setFusedNode("Attention", input);
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE {
if (Subgraph::match(net, nodeId, matchedNodesIds)) {
// get attrs - qkv_hidden_sizes
qkv_hidden_sizes.clear();
auto fill_qkv_hidden_sizes = [&] (const int slice_node_id) {
int slice_start = extractConstant(net, matchedNodesIds[slice_node_id], 1).at<int>(0);
int slice_end = extractConstant(net, matchedNodesIds[slice_node_id], 2).at<int>(0);
if (slice_end == std::numeric_limits<int>::max()) {
qkv_hidden_sizes.push_back(0); // workaround for Slice with end=INT_MAX
} else {
int64_t hidden_size = static_cast<int64_t>(slice_end - slice_start);
qkv_hidden_sizes.push_back(hidden_size);
}
};
fill_qkv_hidden_sizes(slice_q);
fill_qkv_hidden_sizes(slice_k);
fill_qkv_hidden_sizes(slice_v);
CV_CheckEQ(qkv_hidden_sizes.size(), static_cast<size_t>(3), "ONNXSimplifier/Attention: invalid qkv hidden sizes");
CV_CheckEQ(int(qkv_hidden_sizes[0]), int(qkv_hidden_sizes[1]), "ONNXSimplifier/Attention: invalid qkv hidden sizes, q_hidden_size == v_hidden_size is required");
// get attrs - num_heads, scale
num_heads = 1;
scale = extractConstant(net, matchedNodesIds[div_q], 1).at<float>(0);
output_ndims = extractConstant(net, matchedNodesIds[last_reshape], 1).size[0];
// get names
weight_name = getInputName(net, matchedNodesIds[att_matmul], 1);
bias_name = getInputName(net, matchedNodesIds[att_add], 0);
return true;
}
return false;
}
virtual void finalize(const Ptr<ImportGraphWrapper>& net,
const Ptr<ImportNodeWrapper>& fusedNode,
std::vector<Ptr<ImportNodeWrapper> >&) CV_OVERRIDE {
// add attrs
opencv_onnx::NodeProto* node = fusedNode.dynamicCast<ONNXNodeWrapper>()->node;
opencv_onnx::AttributeProto* attr_num_heads = node->add_attribute();
attr_num_heads->set_name("num_heads");
attr_num_heads->set_i(num_heads);
opencv_onnx::AttributeProto* attr_qkv_hidden_sizes = node->add_attribute();
attr_qkv_hidden_sizes->set_name("qkv_hidden_sizes");
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[0]);
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[1]);
attr_qkv_hidden_sizes->add_ints(qkv_hidden_sizes[2]);
opencv_onnx::AttributeProto* attr_scale = node->add_attribute();
attr_scale->set_name("scale");
attr_scale->set_f(scale);
// add customized attrs
opencv_onnx::AttributeProto* attr_output_ndims = node->add_attribute();
attr_output_ndims->set_name("output_ndims");
attr_output_ndims->set_i(output_ndims);
// add inputs
node->add_input(weight_name);
node->add_input(bias_name);
}
protected:
int att_matmul, att_add;
int slice_q, slice_k, slice_v;
int div_q, last_reshape;
std::vector<int64_t> qkv_hidden_sizes; // order: [qk_hidden_size, qk_hidden_size, v_hidden_size]
int64_t num_heads;
float scale;
int64_t output_ndims;
std::string weight_name;
std::string bias_name;
};
/* Fusion for Gelu.
Graph before fusion:
@ -390,21 +698,6 @@ public:
return axis_;
}
static std::string getInputName(const Ptr<ImportGraphWrapper>& net, int node_id, int input_id)
{
auto onnx_net = net.dynamicCast<ONNXGraphWrapper>();
int initializer_id = onnx_net->getInputInitializerId(node_id, input_id);
if (initializer_id != -1)
{
return onnx_net->getNameOfInitializer(initializer_id);
}
else
{
const auto node = net->getNode(node_id);
return node->getInputName(input_id);
}
}
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
std::vector<int>& matchedNodesIds) CV_OVERRIDE
{
@ -1252,6 +1545,10 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
subgraphs.push_back(makePtr<MishSubgraph>());
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
if (getParam_DNN_BACKEND_DEFAULT() == DNN_BACKEND_OPENCV) {
subgraphs.push_back(makePtr<AttentionSubGraph>());
subgraphs.push_back(makePtr<AttentionSingleHeadSubGraph>());
}
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
}


@ -207,6 +207,7 @@ private:
void parseQConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseQGemm (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseQSoftmax (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseAttention (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
// '???' domain or '???' layer type
void parseCustomLayer (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@ -3894,6 +3895,31 @@ void ONNXImporter::parseQSoftmax(LayerParams& layerParams, const opencv_onnx::No
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseAttention(LayerParams& params, const opencv_onnx::NodeProto& node_proto) {
CV_CheckTrue(params.has("num_heads"), "ONNXImporter/parseAttention: num_heads is required but missing");
CV_CheckTrue(params.has("qkv_hidden_sizes"), "ONNXImporter/parseAttention: qkv_hidden_sizes is required but missing");
auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes");
CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "ONNXImporter/parseAttention: qkv_hidden_sizes must have exactly three elements");
for (size_t i = 1; i < node_proto.input_size(); i++) {
if (layer_id.find(node_proto.input(i)) == layer_id.end()) {
Mat tensor = getBlob(node_proto, i);
LayerParams const_params;
const_params.name = node_proto.input(i);
const_params.type = "Const";
const_params.blobs.push_back(tensor);
opencv_onnx::NodeProto proto;
proto.add_output(const_params.name);
addLayer(const_params, proto);
}
}
addLayer(params, node_proto);
}
// Domain: ai.onnx (default)
// URL: https://github.com/onnx/onnx/blob/master/docs/Operators.md
void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
@ -3977,6 +4003,11 @@ void ONNXImporter::buildDispatchMap_ONNX_AI(int opset_version)
dispatch["QLinearConv"] = &ONNXImporter::parseQConv;
dispatch["QLinearMatMul"] = &ONNXImporter::parseQMatMul;
// com.microsoft: This operator is added for compatibility via the onnx graph simplifier.
// The opset domain cannot be modified from onnx_graph_simplifier.cpp, so this
// operator cannot be parsed if it is only added in buildDispatchMap_COM_MICROSOFT.
dispatch["Attention"] = &ONNXImporter::parseAttention;
domain_dispatch_map[str_domain_ai_onnx] = dispatch;
}
@ -3994,6 +4025,7 @@ void ONNXImporter::buildDispatchMap_COM_MICROSOFT(int opset_version)
dispatch["QLinearConcat"] = &ONNXImporter::parseQConcat;
dispatch["QGemm"] = &ONNXImporter::parseQGemm;
dispatch["QLinearSoftmax"] = &ONNXImporter::parseQSoftmax;
dispatch["Attention"] = &ONNXImporter::parseAttention;
domain_dispatch_map["com.microsoft"] = dispatch;
}


@ -130,4 +130,13 @@ TEST_F(Test_Graph_Simplifier, MishSubgraph) {
test("mish", "Mish");
}
TEST_F(Test_Graph_Simplifier, AttentionSubgraph) {
/* Test for 2 subgraphs
- AttentionSubgraph
- AttentionSingleHeadSubgraph
*/
test("attention", "Attention");
test("attention_single_head", "Attention");
}
}}


@ -2949,6 +2949,63 @@ TEST_P(Test_ONNX_layers, Expand_shape_model4) {
testONNXModels("test_expand_shape_model4", pb, 0, 0, false, true, 1);
}
TEST_P(Test_ONNX_layers, Attention) {
testONNXModels("attention");
}
TEST_P(Test_ONNX_layers, AttentionSingleHead) {
testONNXModels("attention_single_head");
}
TEST_P(Test_ONNX_nets, ViT_B_32) {
applyTestTag(CV_TEST_TAG_LONG, CV_TEST_TAG_DEBUG_LONG);
if (backend == DNN_BACKEND_CUDA && target == DNN_TARGET_CUDA_FP16)
{
// does not pass test for now
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA_FP16);
}
const std::string model_path = _tf("models/vit_b_32.onnx", false);
auto net = readNet(model_path);
ASSERT_FALSE(net.empty());
net.setPreferableBackend(backend);
net.setPreferableTarget(target);
auto image = imread(_tf("../googlenet_0.png"));
auto blob = blobFromImage(image, 1.f, Size(224, 224));
auto ref = blobFromNPY(_tf("data/output_vit_b_32.npy"));
checkBackend(&blob, &ref);
net.setInput(blob);
auto out = net.forward();
normAssert(ref, out, "ViTB_32", default_l1, default_lInf);
}
TEST_P(Test_ONNX_nets, VitTrack) {
auto image = imread(_tf("../dog_orig_size.png"));
auto input0 = blobFromImage(image, 1.f, Size(128, 128));
auto input1 = blobFromImage(image, 1.f, Size(256, 256));
auto net = readNet(_tf("models/object_tracking_vittrack_2023sep.onnx", false));
net.setInput(input0, "template");
net.setInput(input1, "search");
std::vector<std::string> output_names{"output1", "output2", "output3"};
std::vector<Mat> outputs;
net.forward(outputs, output_names);
auto ref_output1 = blobFromNPY(_tf("data/output_object_tracking_vittrack_2023sep_0.npy"));
auto ref_output2 = blobFromNPY(_tf("data/output_object_tracking_vittrack_2023sep_1.npy"));
auto ref_output3 = blobFromNPY(_tf("data/output_object_tracking_vittrack_2023sep_2.npy"));
normAssert(ref_output1, outputs[0], "VitTrack output1");
normAssert(ref_output2, outputs[1], "VitTrack output2");
normAssert(ref_output3, outputs[2], "VitTrack output3");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
}} // namespace