// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include "../precomp.hpp"
#include "cpu_kernels/fast_gemm.hpp"
#include "cpu_kernels/softmax.hpp"

#include <cmath>

namespace cv { namespace dnn {

static void packWeight(size_t num_heads, size_t head_size, size_t input_hidden_size,
                       const float *weight_data, size_t hidden_size,
                       std::vector<float> &packed_weight, const FastGemmOpt &opt) {
    // num_heads * pack(head_size, input_hidden_size)
    size_t pack_size = fastGemmPackBSize(head_size, input_hidden_size, opt);
    size_t packed_weight_size = num_heads * pack_size;
    packed_weight.resize(packed_weight_size, 0.f);
    auto *packed_weight_data = packed_weight.data();
    for (size_t i = 0; i < num_heads; i++) {
        fastGemmPackB(false, head_size, input_hidden_size, weight_data, hidden_size, packed_weight_data, opt);
        packed_weight_data += pack_size;
        weight_data += head_size;
    }
}

// Operator spec: https://github.com/microsoft/onnxruntime/blob/v1.16.1/docs/ContribOperators.md#com.microsoft.Attention
class AttentionLayerImpl CV_FINAL : public AttentionLayer {
 public:
    AttentionLayerImpl(const LayerParams &params) {
        setParamsFrom(params);

        CV_CheckTrue(params.has("num_heads"), "DNN/Attention: num_heads is required but missing");
        num_heads = params.get<int>("num_heads"); // required, no default value

        CV_CheckTrue(params.has("qkv_hidden_sizes"), "DNN/Attention: qkv_hidden_sizes is required but missing");
        auto param_qkv_hidden_sizes = params.get("qkv_hidden_sizes");
        CV_CheckEQ(param_qkv_hidden_sizes.size(), 3, "DNN/Attention: qkv_hidden_sizes must and only have three elements");
        qkv_hidden_sizes.clear();
        qkv_hidden_sizes.resize(3);
        qkv_hidden_sizes[0] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(0));
        qkv_hidden_sizes[1] = static_cast<size_t>(param_qkv_hidden_sizes.get<int>(1));
        /* v_hidden_size needs to be initialized in finalize in case v_slice_end=INT_MAX */

        qkv_head_sizes.clear();
        qkv_head_sizes.resize(3);
        qkv_head_sizes[0] = static_cast<size_t>(qkv_hidden_sizes[0] / num_heads);
        qkv_head_sizes[1] = static_cast<size_t>(qkv_hidden_sizes[1] / num_heads);

        scale = 1.f / params.get<float>("scale", sqrt(qkv_head_sizes[0]));

        output_ndims = params.get<int>("output_ndims", 3);

        is_prepacked = false;
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE {
        return backendId == DNN_BACKEND_OPENCV;
    }

    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                                 const int requiredOutputs,
                                 std::vector<MatShape> &outputs,
                                 std::vector<MatShape> &internals) const CV_OVERRIDE {
        CV_CheckEQ(inputs.size(), static_cast<size_t>(3), "DNN/Attention: three inputs are required");

        const auto &input_shape = inputs[0];
        const auto &weight_shape = inputs[1];
        const auto &bias_shape = inputs[2];

        CV_CheckEQ(input_shape.size(), static_cast<size_t>(3), "DNN/Attention: invalid input dimension");
        CV_CheckEQ(weight_shape.size(), static_cast<size_t>(2), "DNN/Attention: invalid weight dimension");

        CV_CheckEQ(input_shape[2], weight_shape[0], "DNN/Attention: invalid input shape");
        CV_CheckEQ(weight_shape[1], bias_shape[0], "DNN/Attention: invalid weight or bias shape");

        if (output_ndims == 3) {
            outputs.assign(1, inputs[0]);
        } else if (output_ndims == 2) {
            int batch = input_shape[0], seq_len = input_shape[1], input_hidden_size = input_shape[2];
            MatShape output_shape{batch * seq_len, input_hidden_size};
            outputs.assign(1, output_shape);
        } else {
            CV_Error(Error::StsBadArg, format("DNN/Attention: invalid output dimension %zu, valid value is 2 or 3", output_ndims));
        }

        return false;
    }
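
    /* Worked example of the size bookkeeping above (illustrative numbers only):
     * with num_heads = 12 and qkv_hidden_sizes = {768, 768, 768}, the per-head sizes
     * become qkv_head_sizes = {64, 64, 64}. If no "scale" attribute is given, the
     * constructor falls back to sqrt(qk_head_size), so the effective multiplier is
     * scale = 1 / sqrt(64) = 0.125, i.e. the usual 1/sqrt(d_k) factor of scaled
     * dot-product attention.
     */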
    virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
        opt.init();

        std::vector<Mat> inputs;
        inputs_arr.getMatVector(inputs);

        const auto input_shape = shape(inputs[0]);
        batch_size = static_cast<size_t>(input_shape[0]);
        seq_len = static_cast<size_t>(input_shape[1]);
        input_hidden_size = static_cast<size_t>(input_shape[2]);

        const auto weight_shape = shape(inputs[1]);
        hidden_size = weight_shape[1];
        qkv_hidden_sizes[2] = hidden_size - qkv_hidden_sizes[0] - qkv_hidden_sizes[1];
        qkv_head_sizes[2] = static_cast<size_t>(qkv_hidden_sizes[2] / num_heads);
    }

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        if (inputs_arr.depth() == CV_16F) {
            forward_fallback(inputs_arr, outputs_arr, internals_arr);
            return;
        }

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        // prepack weights on the first run
        if (!is_prepacked) {
            const auto &weight = inputs[1];
            const auto *weight_data = weight.ptr<const float>();
            packWeight(num_heads, qkv_head_sizes[0], input_hidden_size, weight_data, hidden_size, packed_weight_q, opt);
            packWeight(num_heads, qkv_head_sizes[1], input_hidden_size, weight_data + qkv_hidden_sizes[0], hidden_size, packed_weight_k, opt);
            packWeight(num_heads, qkv_head_sizes[2], input_hidden_size, weight_data + qkv_hidden_sizes[0] + qkv_hidden_sizes[1], hidden_size, packed_weight_v, opt);
            is_prepacked = true;
        }

        float *packed_weights[3] = {packed_weight_q.data(), packed_weight_k.data(), packed_weight_v.data()};
        size_t packed_weights_size[3] = {packed_weight_q.size() / num_heads, packed_weight_k.size() / num_heads, packed_weight_v.size() / num_heads};

        Mat gemm_buffer = Mat::zeros(1, int(batch_size * seq_len * hidden_size), CV_32F);
        auto *Q = gemm_buffer.ptr<float>();
        auto *K = Q + batch_size * seq_len * qkv_hidden_sizes[0];
        auto *V = K + batch_size * seq_len * qkv_hidden_sizes[1];
        float *QKV[3] = {Q, K, V}; // Q, K, V: [B, N, S, H]
        {
            const auto &input = inputs[0];
            const auto &bias = inputs[2];

            const auto *input_data = input.ptr<const float>();
            const auto *bias_data = bias.ptr<const float>();

            opt.multi_thread = false;
            auto fn = [&](const Range &r) {
                for (int i = r.start; i < r.end; i++) {
                    const int batch_index = static_cast<int>((i / 3) / num_heads);
                    const int head_index = static_cast<int>((i / 3) % num_heads);
                    const int qkv_index = static_cast<int>(i % 3);

                    auto *dst = QKV[qkv_index];
                    size_t head_size = qkv_head_sizes[qkv_index];

                    int input_offset = batch_index * seq_len * input_hidden_size;
                    int bias_offset = qkv_index * qkv_hidden_sizes[0] + head_index * head_size;
                    int dst_offset = (batch_index * num_heads + head_index) * (seq_len * head_size);

                    // broadcast bias ([NH] -> [BN, SH]) and make copy to dst
                    const auto *bias_data_src = bias_data + bias_offset;
                    auto *dst_data = dst + dst_offset;
                    for (size_t seq_len_idx = 0; seq_len_idx < seq_len; seq_len_idx++) {
                        std::memcpy(dst_data, bias_data_src, head_size * sizeof(float));
                        dst_data += head_size;
                    }

                    auto *packed_weight = packed_weights[qkv_index] + packed_weights_size[qkv_index] * head_index;
                    // single-thread gemm kernel
                    fastGemm(false, seq_len, head_size, input_hidden_size,
                             1.f, input_data + input_offset, input_hidden_size,
                             packed_weight, 1.f, dst + dst_offset, head_size, opt);
                }
            };

            size_t loops = 3 * batch_size * num_heads;
            double nstripes = loops * seq_len * qkv_head_sizes[0] * input_hidden_size * (1 / 1024.0);
            parallel_for_(Range(0, loops), fn, nstripes);
        }
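
        /* Sketch of the remaining computation (shapes refer to the buffers filled above,
         * with B = batch_size, N = num_heads, S = seq_len):
         *   Q, K: [B*N, S, qk_head_size],  V: [B*N, S, v_head_size]
         *   attention_prob = softmax(scale * Q x K^T)   -> [B*N, S, S]
         *   context        = attention_prob x V         -> [B*N, S, v_head_size]
         * Each context slice is then scattered back head-by-head so the final output is
         * laid out as [B, S, N * v_head_size] = [batch, seq_len, v_hidden_size].
         */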
        // Compute softmax(scale * matmul(Q, K))
        std::vector<int> attention_prob_shape{int(batch_size * num_heads), int(seq_len), int(seq_len)};
        Mat attention_prob = Mat::zeros(attention_prob_shape.size(), attention_prob_shape.data(), CV_32F);
        {
            auto *output = attention_prob.ptr<float>();

            auto loops = batch_size * num_heads;
            auto seq_len_square = seq_len * seq_len;
            auto qk_head_size = qkv_head_sizes[0];
            auto qk_inner_size = seq_len * qk_head_size;

            // Compute scale * matmul(Q, K)
            opt.multi_thread = false;
            parallel_for_(Range(0, loops), [&](const Range &r) {
                for (int i = r.start; i < r.end; i++) {
                    const int output_offset = i * seq_len_square;

                    const auto *q = Q + qk_inner_size * i, *k = K + qk_inner_size * i;
                    fastGemm(false, true, seq_len, qk_head_size, seq_len, qk_head_size,
                             scale, q, qk_head_size, 1, k, qk_head_size, 1,
                             0.f, output + output_offset, seq_len, opt);
                }
            }, loops * seq_len * qk_head_size * seq_len * (1 / 1024.0));

            // Compute softmax
            softmax(attention_prob, attention_prob, attention_prob_shape.size() - 1);
        }

        // Compute np.matmul(attention_prob, V)
        Mat output_buffer = Mat::zeros(1, int(batch_size * num_heads * seq_len * qkv_head_sizes[2]), CV_32F);
        {
            auto *output = outputs[0].ptr<float>();
            auto *output_buff = output_buffer.ptr<float>();
            const auto *prob = attention_prob.ptr<const float>();

            auto loops = batch_size * num_heads;
            auto prob_inner_size = seq_len * seq_len;
            auto v_head_size = qkv_head_sizes[2];
            auto v_inner_size = seq_len * v_head_size;

            opt.multi_thread = false;
            parallel_for_(Range(0, loops), [&](const Range &r) {
                for (int i = r.start; i < r.end; i++) {
                    const int output_offset = i * v_inner_size;

                    const auto *p = prob + i * prob_inner_size, *v = V + i * v_inner_size;
                    fastGemm(false, false, seq_len, seq_len, seq_len, v_head_size,
                             1.f, p, seq_len, 1, v, v_head_size, 1,
                             0.f, output_buff + output_offset, v_head_size, opt);

                    // transpose on the fly
                    const int batch_index = static_cast<int>(i / num_heads);
                    const int head_index = static_cast<int>(i % num_heads);
                    auto *src = output_buff + output_offset;
                    auto *dst = output + (batch_index * seq_len * num_heads + head_index) * v_head_size;
                    for (int j = 0; j < seq_len; j++) {
                        std::memcpy(dst, src, v_head_size * sizeof(float));
                        src += v_head_size;
                        dst += qkv_hidden_sizes[2];
                    }
                }
            }, loops * seq_len * seq_len * v_head_size * (1 / 1024.0));
        }
    }

 private:
    size_t num_heads;
    std::vector<size_t> qkv_hidden_sizes; // order: {qk_hidden_size, qk_hidden_size, v_hidden_size}
    float scale;
    size_t output_ndims;

    std::vector<size_t> qkv_head_sizes; // order: {qk_head_size, qk_head_size, v_head_size}

    size_t batch_size;
    size_t seq_len;
    size_t input_hidden_size;
    size_t hidden_size;

    bool is_prepacked;
    std::vector<float> packed_weight_q;
    std::vector<float> packed_weight_k;
    std::vector<float> packed_weight_v;

    FastGemmOpt opt;
};

Ptr<AttentionLayer> AttentionLayer::create(const LayerParams &params) {
    return makePtr<AttentionLayerImpl>(params);
}

}} // cv::dnn
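
/* Minimal usage sketch (illustrative only; attribute values are made up and the layer
 * is normally created by the ONNX importer from a com.microsoft.Attention node):
 *
 *   cv::dnn::LayerParams params;
 *   params.set("num_heads", 12);
 *   int qkv_sizes[3] = {768, 768, 768};
 *   params.set("qkv_hidden_sizes", cv::dnn::DictValue::arrayInt(qkv_sizes, 3));
 *   cv::Ptr<cv::dnn::AttentionLayer> attn = cv::dnn::AttentionLayer::create(params);
 *   // forward() then expects three inputs: input [B, S, input_hidden_size],
 *   // weight [input_hidden_size, hidden_size] and bias [hidden_size].
 */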