diff --git a/modules/dnn/CMakeLists.txt b/modules/dnn/CMakeLists.txt
index 774e3c7b5a..5963eb68d3 100644
--- a/modules/dnn/CMakeLists.txt
+++ b/modules/dnn/CMakeLists.txt
@@ -9,6 +9,7 @@ ocv_add_dispatched_file_force_all("int8layers/layers_common" AVX2 AVX512_SKX LASX)
 ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_block" AVX AVX2)
 ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_depthwise" AVX AVX2 RVV LASX)
 ocv_add_dispatched_file_force_all("layers/cpu_kernels/conv_winograd_f63" AVX AVX2)
+ocv_add_dispatched_file_force_all("layers/cpu_kernels/fast_gemm_kernels" AVX AVX2 NEON LASX)
 
 ocv_add_module(dnn opencv_core opencv_imgproc WRAP python java objc js)
 
diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index e133ffea65..d92e060fff 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -1101,6 +1101,16 @@ CV__DNN_INLINE_NS_BEGIN
         static Ptr create(const LayerParams& params);
     };
 
+    class CV_EXPORTS GemmLayer : public Layer {
+    public:
+        bool trans_a;
+        bool trans_b;
+        float alpha;
+        float beta;
+
+        static Ptr<GemmLayer> create(const LayerParams& params);
+    };
+
     //! @}
     //! @}
 CV__DNN_INLINE_NS_END
diff --git a/modules/dnn/perf/perf_gemm.cpp b/modules/dnn/perf/perf_gemm.cpp
new file mode 100644
index 0000000000..fff2872b82
--- /dev/null
+++ b/modules/dnn/perf/perf_gemm.cpp
@@ -0,0 +1,251 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#include "perf_precomp.hpp"
+#include 
+
+namespace opencv_test {
+
+struct GemmParam_t {
+    std::vector<int> a_shape;
+    std::vector<int> b_shape;
+    std::vector<int> c_shape;
+    bool trans_a;
+    bool trans_b;
+
+    GemmParam_t(std::vector<int> a_shape_, std::vector<int> b_shape_, std::vector<int> c_shape_ = {}, bool trans_a_ = false, bool trans_b_ = false)
+        : a_shape(a_shape_), b_shape(b_shape_), c_shape(c_shape_), trans_a(trans_a_), trans_b(trans_b_) {}
+};
+
+// TODO: Disable most of the test cases except vision transformers to save time
+static const GemmParam_t test_gemm_configs[] = {
+    // vision transformers cases
+    { { 768, 768 }, { 768, 768 }, { 768 } },
+    { { 1024, 1024 }, { 1024, 1024 }, { 1024 } },
+    { { 50, 768 }, { 768, 2304 } },
+    { { 197, 768 }, { 768, 2304 } },
+    { { 50, 1024 }, { 1024, 3072 } },
+    { { 197, 1024 }, { 1024, 3072 } },
+
+// these cases are commented to save testing time
+/*
+    // square mat
+    { { 64, 64 }, { 64, 64 } },
+    { { 128, 128 }, { 128, 128 } },
+    { { 256, 256 }, { 256, 256 } },
+    { { 512, 512 }, { 512, 512 } },
+    { { 1024, 1024 }, { 1024, 1024 } },
+    { { 4096, 4096 }, { 4096, 4096 } },
+
+    // rectangular mat
+    { { 256, 256 }, { 256, 1024 } },
+    { { 256, 1024 }, { 1024, 256 } },
+    { { 256, 1024 }, { 1024, 1024 } },
+    { { 1024, 1024 }, { 1024, 256 } },
+    { { 1024, 256 }, { 256, 1024 } },
+    { { 1024, 256 }, { 256, 256 } },
+
+    // with C
+    { { 256, 256 }, { 256, 256 }, { 256 } },
+    { { 256, 256 }, { 256, 1024 }, { 1024 } },
+    { { 256, 1024 }, { 1024, 256 }, { 256 } },
+    { { 256, 1024 }, { 1024, 1024 }, { 1024 } },
+    { { 1024, 1024 }, { 1024, 256 }, { 256 } },
+    { { 1024, 256 }, { 256, 1024 }, { 1024 } },
+    { { 1024, 256 }, { 256, 256 }, { 256 } },
+
+    // with C and trans_b
+    { { 256, 256 }, { 256, 256 }, { 256 } , false, true},
+    { { 256, 1024 }, { 256, 1024 }, { 256 } , false, true},
+    { { 256, 1024 }, { 1024, 1024 }, { 1024 } , false, true},
+    { { 1024, 1024 }, { 1024, 1024 }, { 1024 } , false, true},
+    { { 1024, 256 }, { 1024, 256 }, { 1024 } , false, true},
+    { { 1024, 256 }, { 256, 256 }, { 256 } , false, true},
+
+    // with C and trans_b and trans_a
+    { { 256, 256 }, { 256, 256 }, { 256 } , true, true},
+    { { 1024, 256 }, { 256, 1024 }, { 256 } , true, true},
+    { { 256, 1024 }, { 1024, 256 }, { 1024 } , true, true},
+    { { 1024, 1024 }, { 1024, 1024 }, { 1024 } , true, true},
+*/
+};
+
+struct GemmParamId
+{
+    enum {
+        GEMM_0 = 0,
+        GEMM_LAST = sizeof(test_gemm_configs) / sizeof(test_gemm_configs[0])
+    };
+    int val_;
+    GemmParamId(int val = 0) : val_(val) {}
+    operator int() const { return val_; }
+    static ::testing::internal::ParamGenerator<GemmParamId> all()
+    {
+        enum { NUM = (int)GEMM_LAST };
+        GemmParamId v_[NUM]; for (int i = 0; i < NUM; ++i) { v_[i] = GemmParamId(i); } // reduce generated code size
+        return ::testing::ValuesIn(v_, v_ + NUM);
+    }
+};
+
+static inline void PrintTo(const GemmParamId& v, std::ostream* os)
+{
+    CV_Assert((int)v >= 0); CV_Assert((int)v < GemmParamId::GEMM_LAST);
+    const GemmParam_t& p = test_gemm_configs[(int)v];
+
+    auto print_shape = [os](const std::vector<int>& shape, const std::string tag) {
+        if (shape.empty()) {
+            return ;
+        }
+
+        *os << tag << "=[";
+        for (size_t i = 0; i < shape.size(); ++i) {
+            if (i == shape.size() - 1) {
+                *os << shape[i] << "]";
+                break;
+            }
+            *os << shape[i] << ", ";
+        }
+    };
+
+    print_shape(p.a_shape, "A");
+    print_shape(p.b_shape, ", B");
+    print_shape(p.c_shape, ", C");
+    *os << ", trans_a=" << p.trans_a << ", trans_b=" << p.trans_b;
+}
+
+typedef tuple<GemmParamId, tuple<Backend, Target> > GemmTestParam_t;
+typedef TestBaseWithParam<GemmTestParam_t> Gemm;
+
+PERF_TEST_P_(Gemm, gemm)
+{
+    int test_id = (int)get<0>(GetParam());
+    ASSERT_GE(test_id, 0); ASSERT_LT(test_id, GemmParamId::GEMM_LAST);
+    const GemmParam_t& params = test_gemm_configs[test_id];
+    auto a_shape = params.a_shape;
+    auto b_shape = params.b_shape;
+    auto c_shape = params.c_shape;
+    auto trans_a = params.trans_a;
+    auto trans_b = params.trans_b;
+    float alpha = 1.f;
+    float beta = 1.f;
+
+    Backend backend_id = get<0>(get<1>(GetParam()));
+    Target target_id = get<1>(get<1>(GetParam()));
+
+    bool have_bias = c_shape.empty() ? false : true;
+
+    Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
+    randu(A, -1.0f, 1.0f);
+    Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
+    randu(B, -1.0f, 1.0f);
+
+    LayerParams lp;
+    lp.type = "Gemm";
+    lp.name = "testLayer";
+    lp.set("transA", trans_a);
+    lp.set("transB", trans_b);
+    lp.set("alpha", alpha);
+    lp.set("beta", beta);
+    lp.set("real_ndims_C", static_cast<int>(c_shape.size()));
+
+    lp.set("constB", true);
+    lp.blobs.push_back(B);
+    if (have_bias) {
+        Mat C(static_cast<int>(c_shape.size()), c_shape.data(), CV_32F);
+        randu(C, -1.0f, 1.0f);
+        lp.set("have_bias", true);
+        lp.set("constC", true);
+        lp.blobs.push_back(C);
+    }
+
+    Net net;
+    int id = net.addLayerToPrev(lp.name, lp.type, lp);
+    net.connect(0, 0, id, 0);
+    net.setPreferableBackend(backend_id);
+    net.setPreferableTarget(target_id);
+
+    // warmup
+    {
+        net.setInput(A);
+        Mat out = net.forward();
+    }
+
+    TEST_CYCLE()
+    {
+        Mat res = net.forward();
+    }
+
+    SANITY_CHECK_NOTHING();
+}
+
+PERF_TEST_P_(Gemm, innerproduct)
+{
+    int test_id = (int)get<0>(GetParam());
+    ASSERT_GE(test_id, 0); ASSERT_LT(test_id, GemmParamId::GEMM_LAST);
+    const GemmParam_t& params = test_gemm_configs[test_id];
+    auto a_shape = params.a_shape;
+    auto b_shape = params.b_shape;
+    auto c_shape = params.c_shape;
+    auto trans_a = params.trans_a;
+    auto trans_b = params.trans_b;
+
+    Backend backend_id = get<0>(get<1>(GetParam()));
+    Target target_id = get<1>(get<1>(GetParam()));
+
+    bool have_bias = c_shape.empty() ? false : true;
+
+    Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
+    randu(A, -1.0f, 1.0f);
+    Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
+    randu(B, -1.0f, 1.0f);
+
+    LayerParams lp;
+    lp.type = "InnerProduct";
+    lp.name = "testLayer";
+    if (trans_a) {
+        cv::transpose(A, A);
+    }
+    if (!trans_b) {
+        cv::transpose(B, B);
+    }
+    lp.blobs.push_back(B);
+    lp.set("num_output", B.size[0]);
+    if (have_bias) {
+        Mat C(static_cast<int>(c_shape.size()), c_shape.data(), CV_32F);
+        randu(C, -1.0f, 1.0f);
+        lp.blobs.push_back(C);
+        lp.set("bias_term", true);
+    } else {
+        lp.set("bias_term", false);
+    }
+
+    Net net;
+    int id = net.addLayerToPrev(lp.name, lp.type, lp);
+    net.connect(0, 0, id, 0);
+    net.setPreferableBackend(backend_id);
+    net.setPreferableTarget(target_id);
+
+    // warmup
+    {
+        std::vector<std::string> input_names(2);
+        input_names[0] = "A";
+        net.setInputsNames(input_names);
+        net.setInput(A, input_names[0]);
+        Mat out = net.forward();
+    }
+
+    TEST_CYCLE()
+    {
+        Mat res = net.forward();
+    }
+
+    SANITY_CHECK_NOTHING();
+}
+
+INSTANTIATE_TEST_CASE_P(/**/, Gemm, Combine(
+    GemmParamId::all(),
+    dnnBackendsAndTargets(false, false) // defined in ../test/test_common.hpp
+));
+
+} // namespace
\ No newline at end of file
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index 2ce54ac0bb..8786ce484d 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -101,6 +101,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Reduce, ReduceLayer);
     CV_DNN_REGISTER_LAYER_CLASS(LRN, LRNLayer);
     CV_DNN_REGISTER_LAYER_CLASS(InnerProduct, InnerProductLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(Gemm, GemmLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Softmax, SoftmaxLayer);
     CV_DNN_REGISTER_LAYER_CLASS(SoftMax, SoftmaxLayer); // For compatibility.
See https://github.com/opencv/opencv/issues/16877 CV_DNN_REGISTER_LAYER_CLASS(MVN, MVNLayer); diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp new file mode 100644 index 0000000000..b7aa18d486 --- /dev/null +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp @@ -0,0 +1,262 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// This file is modified from the ficus (https://github.com/vpisarev/ficus/blob/master/runtime/ficus/impl/gemm.impl.h). +// Here is the original license: +/* + This file is a part of ficus language project. + See ficus/LICENSE for the licensing terms +*/ + +#include "../../precomp.hpp" +#include "fast_gemm.hpp" + +#define CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY +#include "fast_gemm_kernels.simd.hpp" +#include "layers/cpu_kernels/fast_gemm_kernels.simd_declarations.hpp" +#undef CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY +#include "fast_gemm_kernels.default.hpp" + +namespace cv { namespace dnn { + +void fastGemmPackB(const Mat &B, std::vector &packed_B, bool trans, FastGemmOpt &opt) { + CV_CheckEQ(B.dims, 2, "fastGemmPackB: input mat should be two-dimensional"); + CV_CheckTypeEQ(B.type(), CV_32F, "fastGemmPackB: only float32 is supported for now"); + + auto B_shape = shape(B); + int K = B_shape[0], N = B_shape[1], ldb0 = N, ldb1 = 1; + if (trans) { + std::swap(K, N); + std::swap(ldb0, ldb1); + } + +#if CV_TRY_NEON + if (opt.use_neon) { + int size_packed_B = opt_NEON::fastGemmPackBSize(N, K); + packed_B.resize(size_packed_B); + opt_NEON::fastGemmPackBKernel(B.ptr(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize()); + } else +#endif +#if CV_TRY_AVX2 + if (opt.use_avx2) { + int size_packed_B = opt_AVX2::fastGemmPackBSize(N, K); + packed_B.resize(size_packed_B); + opt_AVX2::fastGemmPackBKernel(B.ptr(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize()); + } else +#endif +#if CV_TRY_AVX + if (opt.use_avx) { + int size_packed_B = opt_AVX::fastGemmPackBSize(N, K); + packed_B.resize(size_packed_B); + opt_AVX::fastGemmPackBKernel(B.ptr(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize()); + } else +#endif +#if CV_TRY_LASX + if (opt.use_lasx) { + int size_packed_B = opt_LASX::fastGemmPackBSize(N, K); + packed_B.resize(size_packed_B); + opt_LASX::fastGemmPackBKernel(B.ptr(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize()); + } else +#endif + { + int size_packed_B = cpu_baseline::fastGemmPackBSize(N, K); + packed_B.resize(size_packed_B); + cpu_baseline::fastGemmPackBKernel(B.ptr(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize()); + } +} + +static void fast_gemm_thin(float alpha, float beta, int M, int N, int K, + const char *a_, int lda0, int lda1, + const char *b_, int ldb, + char *c_, int ldc) { + const float* a = (const float*)a_; + + auto fn = [&](const Range &r) { + for(int start = r.start ; start < r.end; start++ ) { + float* c_i = (float*)c_ + start * ldc; + if (beta == 0.f) + for(int j = 0; j < N; j++ ) c_i[j] = 0.f; + else if (beta != 1.f) + for(int j = 0; j < N; j++ ) c_i[j] *= beta; + for(int k = 0; k < K; k++ ) { + const float* b_k = (const float*)b_ + k * ldb; + float aval = alpha * a[start * lda0 + k * lda1]; + for(int j = 0; j < N; j++ ) + c_i[j] += aval * b_k[j]; + } + } + }; + + int total = M; // outer loops + int cost_per_thread = static_cast(K * N); // inner loops + double nstripes = (size_t)total * cost_per_thread * (1 
/ 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +void fastGemm(bool trans_a, int M, int N, int K, + float alpha, const float *A, int lda, + const float *packed_B, float beta, + float *C, int ldc, FastGemmOpt &opt) { + int lda0 = lda, lda1 = 1; + if (trans_a) { + std::swap(lda0, lda1); + } + +#if CV_TRY_NEON + if (opt.use_neon) { + opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + } else +#endif +#if CV_TRY_AVX2 + if (opt.use_avx2) { + opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + } else +#endif +#if CV_TRY_AVX + if (opt.use_avx) { + opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + } else +#endif +#if CV_TRY_LASX + if (opt.use_lasx) { + opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + } else +#endif + { + cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, (const char *)packed_B, beta, (char *)C, ldc, sizeof(float)); + } +} + +void fastGemm(bool trans_a, bool trans_b, int ma, int na, int mb, int nb, + float alpha, const float *A, int lda0, int lda1, const float *B, int ldb0, int ldb1, + float beta, float *C, int ldc, FastGemmOpt &opt) { + + const char *a = (const char *)A; + const char *b = (const char *)B; + char *c = (char *)C; + + int M = trans_a ? na : ma; + int N = trans_b ? mb : nb; + int K = trans_a ? ma : na; + + if (trans_a) { + std::swap(lda0, lda1); + } + if (trans_b) { + std::swap(ldb0, ldb1); + } + + if (!trans_b && ldb1 == 1 && (M <= 4 || (uint64_t)M * N * K <= 10000)) { + return fast_gemm_thin(alpha, beta, M, N, K, a, lda0, lda1, b, ldb0, c, ldc); + } + +#if CV_TRY_NEON + if (opt.use_neon) { + opt_NEON::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, + (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + } else +#endif +#if CV_TRY_AVX2 + if (opt.use_avx2) { + opt_AVX2::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, + (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + } else +#endif +#if CV_TRY_AVX + if (opt.use_avx) { + opt_AVX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, + (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + } else +#endif +#if CV_TRY_LASX + if (opt.use_lasx) { + opt_LASX::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, + (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + } else +#endif + { + cpu_baseline::fastGemmKernel(M, N, K, alpha, (const char *)A, lda0, lda1, + (const char *)B, ldb0, ldb1, beta, (char *)C, ldc, sizeof(float)); + } +} + +void fastGemm(bool trans_a, bool trans_b, + float alpha, const Mat &A, const Mat &B, + float beta, Mat &C, FastGemmOpt &opt) { + CV_CheckTypeEQ(A.type(), CV_32F, "DNN/fastGemm: only support float32 for now"); + CV_CheckTypeEQ(A.type(), B.type(), "DNN/fastGemm: A and B should have the same type"); + CV_CheckTypeEQ(B.type(), C.type(), "DNN/fastGemm: B and C should have the same type"); + + const auto shape_a = shape(A); + CV_CheckEQ(shape_a.size(), static_cast(2), "DNN/fastGemm: A must be 2-dimensional"); + const auto shape_b = shape(B); + CV_CheckEQ(shape_b.size(), static_cast(2), "DNN/fastGemm: B must be 2-dimensional"); + const auto shape_c = shape(C); + CV_CheckEQ(shape_c.size(), static_cast(2), "DNN/fastGemm: C must be 
2-dimensional"); + + int ma = shape_a[0], na = shape_a[1]; + int mb = shape_b[0], nb = shape_b[1]; + + int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1]; + + const float *a = A.ptr(); + const float *b = B.ptr(); + float *c = C.ptr(); + + fastGemm(trans_a, trans_b, ma, na, mb, nb, + alpha, a, lda0, lda1, b, ldb0, ldb1, + beta, c, ldc, opt); +} + +void fastGemmBatched(bool trans_a, bool trans_b, + float alpha, const Mat &A, const Mat &B, + float beta, Mat &C, FastGemmOpt &opt) { + CV_CheckTypeEQ(A.type(), B.type(), "DNN/fastGemmBatched: A and B should have the same type"); + CV_CheckTypeEQ(B.type(), C.type(), "DNN/fastGemmBatched: B and C should have the same type"); + CV_CheckTypeEQ(A.type(), CV_32F, "DNN/fastGemmBatched: only support float32 for now"); + + const auto shape_a = shape(A); + size_t dims_A = shape_a.size(); + CV_CheckGE(dims_A, static_cast(2), "DNN/fastGemmBatched: A must be n-dimensional (n >= 2)"); + const auto shape_b = shape(B); + CV_CheckEQ(shape_b.size(), static_cast(2), "DNN/fastGemmBatched: B must be 2-dimensional"); + const auto shape_c = shape(C); + size_t dims_C = shape_c.size(); + CV_CheckGE(dims_C, static_cast(2), "DNN/fastGemmBatched: C must be n-dimensional (n >= 2)"); + + if (trans_a) { + int ma = shape_a[dims_A - 2], na = shape_a[dims_A - 1]; + int mb = shape_b[0], nb = shape_b[1]; + + int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1]; + + const float *a = A.ptr(); + const float *b = B.ptr(); + float *c = C.ptr(); + + int batches = std::accumulate(shape_a.begin(), shape_a.end() - 2, 1, std::multiplies()); + int step_a = ma * na, step_c = na * nb; + for (int i = 0; i < batches; i++) { + fastGemm(true, trans_b, ma, na, mb, nb, + alpha, a + i * step_a, lda0, lda1, b, ldb0, ldb1, + beta, c + i * step_c, ldc, opt); + } + } else { + int ma = std::accumulate(shape_a.begin(), shape_a.end() - 1, 1, std::multiplies()), + na = shape_a[dims_A - 1]; + int mb = shape_b[0], nb = shape_b[1]; + + int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1]; + + const float *a = A.ptr(); + const float *b = B.ptr(); + float *c = C.ptr(); + + fastGemm(false, trans_b, ma, na, mb, nb, + alpha, a, lda0, lda1, b, ldb0, ldb1, + beta, c, ldc, opt); + } +} + +}} // cv::dnn diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp new file mode 100644 index 0000000000..7f9e5c3017 --- /dev/null +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm.hpp @@ -0,0 +1,65 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// This file is modified from the ficus (https://github.com/vpisarev/ficus/blob/master/runtime/ficus/impl/gemm.impl.h). +// Here is the original license: +/* + This file is a part of ficus language project. 
+ See ficus/LICENSE for the licensing terms +*/ + +#ifndef OPENCV_DNN_FAST_GEMM_HPP +#define OPENCV_DNN_FAST_GEMM_HPP + +#include "opencv2/core/hal/intrin.hpp" +#include + +namespace cv { namespace dnn { + +struct FastGemmOpt { + bool use_avx; + bool use_avx2; + bool use_neon; + bool use_lasx; + + FastGemmOpt() { + use_avx = false; + use_avx2 = false; + use_neon = false; + use_lasx = false; + } + + void init() { + use_avx = checkHardwareSupport(CPU_AVX); + use_avx2 = checkHardwareSupport(CPU_AVX2); + use_neon = checkHardwareSupport(CPU_NEON); + use_lasx = checkHardwareSupport(CPU_LASX); + } + + bool all() { + return use_avx || use_avx2 || use_neon || use_lasx; + } +}; + +void fastGemmPackB(const Mat &m, std::vector &packed_B, bool trans, FastGemmOpt &opt); + +void fastGemm(bool trans_a, int M, int N, int K, + float alpha, const float *A, int lda, + const float *packed_B, float beta, + float *C, int ldc, FastGemmOpt &opt); +void fastGemm(bool trans_a, bool trans_b, int ma, int na, int mb, int nb, + float alpha, const float *A, int lda0, int lda1, const float *B, int ldb0, int ldb1, + float beta, float *C, int ldc, FastGemmOpt &opt); +void fastGemm(bool trans_a, bool trans_b, + float alpha, const Mat &A, const Mat &B, + float beta, Mat &C, FastGemmOpt &opt); + +// FIXME: B needs to 2d for now. Support nd (n>=2) B in the future. +void fastGemmBatched(bool trans_a, bool trans_b, + float alpha, const Mat &A, const Mat &B, + float beta, Mat &C, FastGemmOpt &opt); + +}} // cv::dnn + +#endif // OPENCV_DNN_FAST_GEMM_HPP diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp new file mode 100644 index 0000000000..6a8ef6b590 --- /dev/null +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.default.hpp @@ -0,0 +1,393 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// This file is modified from the ficus (https://github.com/vpisarev/ficus/blob/master/runtime/ficus/impl/gemm.impl.h). +// Here is the original license: +/* + This file is a part of ficus language project. + See ficus/LICENSE for the licensing terms +*/ + +#include +#include // parallel_for_ + +#define FAST_GEMM_DEFAULT_STORAGE (1<<20) // 2^20 +#define FAST_GEMM_DEFAULT_MAX_STACKBUF (1 << 14) + +#define FAST_GEMM_DEFAULT_F32_MC 64 +#define FAST_GEMM_DEFAULT_F32_NC 240 +#define FAST_GEMM_DEFAULT_F32_MR 8 +#define FAST_GEMM_DEFAULT_F32_NR 12 +#define FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K 256 + +#define FAST_GEMM_DEFAULT_IMPLEMENT_PACK(N, suffix, styp, dtyp) \ +static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \ + int lda0, int lda1, void* packA_ ) \ +{ \ + const styp* A = (const styp*)A_; \ + dtyp* packA = (dtyp*)packA_; \ + for( int i = 0; i < m; i += N ) { \ + if (i + N-1 < m) { \ + const styp* a_ptr = A + lda0*i; \ + for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \ + { \ + FAST_GEMM_DEFAULT_LOAD_TO_BUF_##N(styp); \ + FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \ + } \ + } else { \ + const styp* a_ptr[N]; \ + for (int k = 0; k < N; k++) a_ptr[k] = A + lda0*(i+k < m ? 
i+k : i); \ + for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \ + { \ + FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_##N(styp); \ + FAST_GEMM_DEFAULT_PACK##suffix##_##N(buf, packA); \ + } \ + } \ + } \ +} + +#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_8(styp) \ + styp buf[] = { \ + a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ + a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7] } + +#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_8(styp) \ + styp buf[] = { \ + a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \ + a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j] } + +#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_12(styp) \ + styp buf[] = { \ + a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ + a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7], \ + a_ptr[j+lda0*8], a_ptr[j+lda0*9], a_ptr[j+lda0*10], a_ptr[j+lda0*11] } + +#define FAST_GEMM_DEFAULT_LOAD_TO_BUF_BORDERS_12(styp) \ + styp buf[] = { \ + a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \ + a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j], \ + a_ptr[8][j], a_ptr[9][j], a_ptr[10][j], a_ptr[11][j] } + +#define FAST_GEMM_DEFAULT_PACK_COPY(src, dst, N) \ + memcpy((dst), (src), N*sizeof(src[0])) +#define FAST_GEMM_DEFAULT_PACK_f32_8(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 8) +#define FAST_GEMM_DEFAULT_PACK_f32_12(src, dst) FAST_GEMM_DEFAULT_PACK_COPY((src), (dst), 12) + +namespace cv { namespace dnn { namespace cpu_baseline { + +int fastGemmPackBSize(int N, int K); + +void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz); + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *B, int ldb0, int ldb1, + float beta, char *C, int ldc, int esz); +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *packed_B, float beta, char *C, int ldc, int esz); + +FAST_GEMM_DEFAULT_IMPLEMENT_PACK(8, _f32, float, float) +FAST_GEMM_DEFAULT_IMPLEMENT_PACK(12, _f32, float, float) + +int fastGemmPackBSize(int N, int K) { + int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + + return static_cast((N + NC - 1) / NC) * NC * K; +} + +void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { + int GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K); + + int n_tiles = (N + NC - 1) / NC; + for (int r = 0; r < n_tiles; ++r) { + int j0 = r * NC; + int nc = N - j0 < NC ? N - j0 : NC; + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for (int k = 0; k < K; k += KC) { + int kc = K - k < KC ? 
K - k : KC; + fast_gemm_pack12_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); + packed_B += _nc * kc; + } + } +} + +#if CV_SIMD128 +static void fast_gemm8x12_f32(int k, const char *a_, const char *b_, + char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; + + v_float32x4 s00 = v_setzero_f32(), s01 = s00, s02 = s00; + v_float32x4 s10 = s00, s11 = s00, s12 = s00; + v_float32x4 s20 = s00, s21 = s00, s22 = s00; + v_float32x4 s30 = s00, s31 = s00, s32 = s00; + v_float32x4 s40 = s00, s41 = s00, s42 = s00; + v_float32x4 s50 = s00, s51 = s00, s52 = s00; + v_float32x4 s60 = s00, s61 = s00, s62 = s00; + v_float32x4 s70 = s00, s71 = s00, s72 = s00; + + for(int p = 0; p < k; p++, a += FAST_GEMM_DEFAULT_F32_MR, b += FAST_GEMM_DEFAULT_F32_NR) { + v_float32x4 b0 = v_load(b), b1 = v_load(b + 4), b2 = v_load(b + 8); + + v_float32x4 a0 = v_setall_f32(*a); + s00 = v_fma(b0, a0, s00); + s01 = v_fma(b1, a0, s01); + s02 = v_fma(b2, a0, s02); + v_float32x4 a1 = v_setall_f32(*(a + 1)); + s10 = v_fma(b0, a1, s10); + s11 = v_fma(b1, a1, s11); + s12 = v_fma(b2, a1, s12); + + v_float32x4 a2 = v_setall_f32(*(a + 2)); + s20 = v_fma(b0, a2, s20); + s21 = v_fma(b1, a2, s21); + s22 = v_fma(b2, a2, s22); + v_float32x4 a3 = v_setall_f32(*(a + 3)); + s30 = v_fma(b0, a3, s30); + s31 = v_fma(b1, a3, s31); + s32 = v_fma(b2, a3, s32); + + a0 = v_setall_f32(*(a + 4)); + s40 = v_fma(b0, a0, s40); + s41 = v_fma(b1, a0, s41); + s42 = v_fma(b2, a0, s42); + a1 = v_setall_f32(*(a + 5)); + s50 = v_fma(b0, a1, s50); + s51 = v_fma(b1, a1, s51); + s52 = v_fma(b2, a1, s52); + + a2 = v_setall_f32(*(a + 6)); + s60 = v_fma(b0, a2, s60); + s61 = v_fma(b1, a2, s61); + s62 = v_fma(b2, a2, s62); + a3 = v_setall_f32(*(a + 7)); + s70 = v_fma(b0, a3, s70); + s71 = v_fma(b1, a3, s71); + s72 = v_fma(b2, a3, s72); + } + + v_float32x4 c0, c1, c2, c3, c4, c5, v_alpha = v_setall_f32(alpha); +#define FAST_GEMM_FINALE(row0, row1) \ + c0 = v_load(c + row0 * ldc); \ + c1 = v_load(c + row0 * ldc + 4); \ + c2 = v_load(c + row0 * ldc + 8); \ + c3 = v_load(c + row1 * ldc); \ + c4 = v_load(c + row1 * ldc + 4); \ + c5 = v_load(c + row1 * ldc + 8); \ + c0 = v_fma(s##row0##0, v_alpha, c0); \ + c1 = v_fma(s##row0##1, v_alpha, c1); \ + c2 = v_fma(s##row0##2, v_alpha, c2); \ + c3 = v_fma(s##row1##0, v_alpha, c3); \ + c4 = v_fma(s##row1##1, v_alpha, c4); \ + c5 = v_fma(s##row1##2, v_alpha, c5); \ + v_store(c + row0 * ldc, c0); \ + v_store(c + row0 * ldc + 4, c1); \ + v_store(c + row0 * ldc + 8, c2); \ + v_store(c + row1 * ldc, c3); \ + v_store(c + row1 * ldc + 4, c4); \ + v_store(c + row1 * ldc + 8, c5); + + FAST_GEMM_FINALE(0, 1); + FAST_GEMM_FINALE(2, 3); + FAST_GEMM_FINALE(4, 5); + FAST_GEMM_FINALE(6, 7); +#undef FAST_GEMM_FINALE +} + +#else +static void fast_gemm_f32(int k, const char *a_, const char *b_, + char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; + + float sbuf[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR]; + memset(sbuf, 0, sizeof(sbuf)); + for(int p = 0; p < k; p++) { + for( int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++ ) { + float ai = a[FAST_GEMM_DEFAULT_F32_MR * p + i]; + for( int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++ ) + sbuf[i * FAST_GEMM_DEFAULT_F32_NR + j] += b[FAST_GEMM_DEFAULT_F32_NR * p + j] * ai; + } + } + for (int i = 0; i < FAST_GEMM_DEFAULT_F32_MR; i++) { + for (int j = 0; j < FAST_GEMM_DEFAULT_F32_NR; j++) + c[i * ldc + j] += alpha * sbuf[i * 
FAST_GEMM_DEFAULT_F32_NR + j]; + } +} +#endif // CV_SIMD128 + +static void fast_gemm_macro_kernel(int m, int n, int k, + const char *packed_A, const char *packed_B, + float alpha, char *c, int ldc0, int esz) { + int ldc0_esz = ldc0 * esz; + + double tempC[FAST_GEMM_DEFAULT_F32_MR * FAST_GEMM_DEFAULT_F32_NR]; // make sure the buffer is big enough + for(int i = 0; i < m; i += FAST_GEMM_DEFAULT_F32_MR) { + for(int j = 0; j < n; j += FAST_GEMM_DEFAULT_F32_NR) { + char* cptr0 = &c[i * ldc0_esz + j * esz]; + char* cptr = cptr0; + int ldc = ldc0; + int mr = m - i < FAST_GEMM_DEFAULT_F32_MR ? m - i : FAST_GEMM_DEFAULT_F32_MR; + int nr = n - j < FAST_GEMM_DEFAULT_F32_NR ? n - j : FAST_GEMM_DEFAULT_F32_NR; + int nr_esz = nr * esz; + bool partial = (bool)((mr < FAST_GEMM_DEFAULT_F32_MR) | (nr < FAST_GEMM_DEFAULT_F32_NR)); + if (partial) { + memset(tempC, 0, sizeof(tempC)); + cptr = (char *)tempC; + ldc = FAST_GEMM_DEFAULT_F32_NR; + for(int p = 0; p < mr; p++) + memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); + } +#if CV_SIMD128 + fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#else + fast_gemm_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#endif + + if (partial) { + for(int p = 0; p < mr; p++) + memcpy(cptr0 + p * ldc0_esz, cptr + p * (ldc * esz), nr_esz); + } + } + } +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *B, int ldb0, int ldb1, + float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC, + GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, + GEMM_MR = FAST_GEMM_DEFAULT_F32_MR, + GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = FAST_GEMM_DEFAULT_STORAGE / ((MC + NC) * esz); + KC = KC > 8 ? KC : 8; + KC = KC < K ? KC : K; + + size_t buff_size = KC * (MC + NC) * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); + char* packed_b = packed_a + KC * MC * esz; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? 
K - k0 : KC; + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); + fast_gemm_pack12_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz); + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *packed_B, float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_DEFAULT_F32_MC, + GEMM_NC = FAST_GEMM_DEFAULT_F32_NC, + GEMM_MR = FAST_GEMM_DEFAULT_F32_MR, + GEMM_NR = FAST_GEMM_DEFAULT_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_DEFAULT_F32_PACKED_STRIDE_K, K); + + size_t buff_size = KC * MC * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_DEFAULT_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); // TODO: use AutoBuffer + const char *packed_b_ = packed_B; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + packed_b_ = packed_B + j0 * K * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? K - k0 : KC; + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz); + packed_b_ += _nc * kc; + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +}}} // cv::dnn::cpu_baseline diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp new file mode 100644 index 0000000000..99a7d3b2d7 --- /dev/null +++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm_kernels.simd.hpp @@ -0,0 +1,1059 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. + +// This file is modified from the ficus (https://github.com/vpisarev/ficus/blob/master/runtime/ficus/impl/gemm.impl.h). +// Here is the original license: +/* + This file is a part of ficus language project. 
+ See ficus/LICENSE for the licensing terms +*/ + +#include +#include // parallel_for_ + +#define FAST_GEMM_STORAGE (1<<20) // 2^20 +#define FAST_GEMM_MAX_STACKBUF (1 << 14) + +#if CV_NEON +#define FAST_GEMM_F32_MC 64 +#define FAST_GEMM_F32_NC 240 +#elif CV_AVX +#define FAST_GEMM_F32_MC 60 +#define FAST_GEMM_F32_NC 320 +#elif CV_LASX +#define FAST_GEMM_F32_MC 48 +#define FAST_GEMM_F32_NC 128 +#endif + +// micro kernel size +#if CV_NEON && CV_NEON_AARCH64 +#define FAST_GEMM_F32_MR 8 +#define FAST_GEMM_F32_NR 12 +#elif CV_NEON +#define FAST_GEMM_F32_MR 4 +#define FAST_GEMM_F32_NR 12 +#elif CV_AVX +#define FAST_GEMM_F32_MR 12 +#define FAST_GEMM_F32_NR 8 +#elif CV_LASX +#define FAST_GEMM_F32_MR 12 +#define FAST_GEMM_F32_NR 16 +#endif + +#if CV_NEON +#define FAST_GEMM_F32_PACKED_STRIDE_K 64 +#elif CV_AVX +#define FAST_GEMM_F32_PACKED_STRIDE_K 128 +#elif CV_LASX +#define FAST_GEMM_F32_PACKED_STRIDE_K 64 +#endif + +#define FAST_GEMM_IMPLEMENT_PACK(N, suffix, styp, dtyp) \ +static void fast_gemm_pack##N##suffix( int m, int k, const void* A_, \ + int lda0, int lda1, void* packA_ ) \ +{ \ + const styp* A = (const styp*)A_; \ + dtyp* packA = (dtyp*)packA_; \ + for( int i = 0; i < m; i += N ) { \ + if (i + N-1 < m) { \ + const styp* a_ptr = A + lda0*i; \ + for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \ + { \ + FAST_GEMM_LOAD_TO_BUF_##N(styp); \ + FAST_GEMM_PACK##suffix##_##N(buf, packA); \ + } \ + } else { \ + const styp* a_ptr[N]; \ + for (int k = 0; k < N; k++) a_ptr[k] = A + lda0*(i+k < m ? i+k : i); \ + for( int j = 0; j < k*lda1; packA += N, j += lda1 ) \ + { \ + FAST_GEMM_LOAD_TO_BUF_BORDERS_##N(styp); \ + FAST_GEMM_PACK##suffix##_##N(buf, packA); \ + } \ + } \ + } \ +} + +#define FAST_GEMM_LOAD_TO_BUF_4(styp) \ + styp buf[] = { \ + a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3] } + +#define FAST_GEMM_LOAD_TO_BUF_BORDERS_4(styp) \ + styp buf[] = { \ + a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j] } + +#define FAST_GEMM_LOAD_TO_BUF_8(styp) \ + styp buf[] = { \ + a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ + a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7] } + +#define FAST_GEMM_LOAD_TO_BUF_BORDERS_8(styp) \ + styp buf[] = { \ + a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \ + a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j] } + +#define FAST_GEMM_LOAD_TO_BUF_12(styp) \ + styp buf[] = { \ + a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ + a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7], \ + a_ptr[j+lda0*8], a_ptr[j+lda0*9], a_ptr[j+lda0*10], a_ptr[j+lda0*11] } + +#define FAST_GEMM_LOAD_TO_BUF_BORDERS_12(styp) \ + styp buf[] = { \ + a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \ + a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j], \ + a_ptr[8][j], a_ptr[9][j], a_ptr[10][j], a_ptr[11][j] } + +#define FAST_GEMM_LOAD_TO_BUF_16(styp) \ + styp buf[] = { \ + a_ptr[j], a_ptr[j+lda0], a_ptr[j+lda0*2], a_ptr[j+lda0*3], \ + a_ptr[j+lda0*4], a_ptr[j+lda0*5], a_ptr[j+lda0*6], a_ptr[j+lda0*7], \ + a_ptr[j+lda0*8], a_ptr[j+lda0*9], a_ptr[j+lda0*10], a_ptr[j+lda0*11], \ + a_ptr[j+lda0*12], a_ptr[j+lda0*13], a_ptr[j+lda0*14], a_ptr[j+lda0*15] } + +#define FAST_GEMM_LOAD_TO_BUF_BORDERS_16(styp) \ + styp buf[] = { \ + a_ptr[0][j], a_ptr[1][j], a_ptr[2][j], a_ptr[3][j], \ + a_ptr[4][j], a_ptr[5][j], a_ptr[6][j], a_ptr[7][j], \ + a_ptr[8][j], a_ptr[9][j], a_ptr[10][j], a_ptr[11][j], \ + a_ptr[12][j], a_ptr[13][j], a_ptr[14][j], a_ptr[15][j] } + +#define FAST_GEMM_PACK_COPY(src, dst, N) \ + memcpy((dst), (src), 
N*sizeof(src[0])) +#define FAST_GEMM_PACK_f32_4(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 4) +#define FAST_GEMM_PACK_f32_8(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 8) +#define FAST_GEMM_PACK_f32_12(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 12) +#define FAST_GEMM_PACK_f32_16(src, dst) FAST_GEMM_PACK_COPY((src), (dst), 16) + +namespace cv { namespace dnn { + +CV_CPU_OPTIMIZATION_NAMESPACE_BEGIN + +// TODO: type to size_t +int fastGemmPackBSize(int N, int K); + +void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz); + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *B, int ldb0, int ldb1, + float beta, char *C, int ldc, int esz); +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *packed_B, float beta, char *C, int ldc, int esz); + +// NEON (AARCH64: 32 x 128-bit registers, armv7: 16 x 128-bit registers) +#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_NEON + +#if CV_NEON_AARCH64 +FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) +#else +FAST_GEMM_IMPLEMENT_PACK(4, _f32, float, float) +#endif +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) + +int fastGemmPackBSize(int N, int K) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + + return static_cast((N + NC - 1) / NC) * NC * K; +} + +void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + + int n_tiles = (N + NC - 1) / NC; + for (int r = 0; r < n_tiles; ++r) { + int j0 = r * NC; + int nc = N - j0 < NC ? N - j0 : NC; + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for (int k = 0; k < K; k += KC) { + int kc = K - k < KC ? 
K - k : KC; + fast_gemm_pack12_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); + packed_B += _nc * kc; + } + } +} + +#if CV_NEON_AARCH64 +static void fast_gemm8x12_f32(int k, const char *a_, const char *b_, + char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; + + float32x4_t s00 = vdupq_n_f32(0.f), s01 = s00, s02 = s00; + float32x4_t s10 = s00, s11 = s00, s12 = s00; + float32x4_t s20 = s00, s21 = s00, s22 = s00; + float32x4_t s30 = s00, s31 = s00, s32 = s00; + float32x4_t s40 = s00, s41 = s00, s42 = s00; + float32x4_t s50 = s00, s51 = s00, s52 = s00; + float32x4_t s60 = s00, s61 = s00, s62 = s00; + float32x4_t s70 = s00, s71 = s00, s72 = s00; + + for(int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) + { + float32x4_t a0 = vld1q_f32(a); + float32x4_t b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); + + s00 = vfmaq_laneq_f32(s00, b0, a0, 0); + s01 = vfmaq_laneq_f32(s01, b1, a0, 0); + s02 = vfmaq_laneq_f32(s02, b2, a0, 0); + s10 = vfmaq_laneq_f32(s10, b0, a0, 1); + s11 = vfmaq_laneq_f32(s11, b1, a0, 1); + s12 = vfmaq_laneq_f32(s12, b2, a0, 1); + + s20 = vfmaq_laneq_f32(s20, b0, a0, 2); + s21 = vfmaq_laneq_f32(s21, b1, a0, 2); + s22 = vfmaq_laneq_f32(s22, b2, a0, 2); + s30 = vfmaq_laneq_f32(s30, b0, a0, 3); + s31 = vfmaq_laneq_f32(s31, b1, a0, 3); + s32 = vfmaq_laneq_f32(s32, b2, a0, 3); + + a0 = vld1q_f32(a + 4); + + s40 = vfmaq_laneq_f32(s40, b0, a0, 0); + s41 = vfmaq_laneq_f32(s41, b1, a0, 0); + s42 = vfmaq_laneq_f32(s42, b2, a0, 0); + s50 = vfmaq_laneq_f32(s50, b0, a0, 1); + s51 = vfmaq_laneq_f32(s51, b1, a0, 1); + s52 = vfmaq_laneq_f32(s52, b2, a0, 1); + + s60 = vfmaq_laneq_f32(s60, b0, a0, 2); + s61 = vfmaq_laneq_f32(s61, b1, a0, 2); + s62 = vfmaq_laneq_f32(s62, b2, a0, 2); + s70 = vfmaq_laneq_f32(s70, b0, a0, 3); + s71 = vfmaq_laneq_f32(s71, b1, a0, 3); + s72 = vfmaq_laneq_f32(s72, b2, a0, 3); + } + + float32x4_t c0, c1, c2, c3, c4, c5, v_alpha = vdupq_n_f32(alpha); +#define FAST_GEMM_FINALE(row0, row1) \ + c0 = vld1q_f32(c + row0 * ldc); \ + c1 = vld1q_f32(c + row0 * ldc + 4); \ + c2 = vld1q_f32(c + row0 * ldc + 8); \ + c3 = vld1q_f32(c + row1 * ldc); \ + c4 = vld1q_f32(c + row1 * ldc + 4); \ + c5 = vld1q_f32(c + row1 * ldc + 8); \ + c0 = vfmaq_f32(c0, s##row0##0, v_alpha); \ + c1 = vfmaq_f32(c1, s##row0##1, v_alpha); \ + c2 = vfmaq_f32(c2, s##row0##2, v_alpha); \ + c3 = vfmaq_f32(c3, s##row1##0, v_alpha); \ + c4 = vfmaq_f32(c4, s##row1##1, v_alpha); \ + c5 = vfmaq_f32(c5, s##row1##2, v_alpha); \ + vst1q_f32(c + row0 * ldc, c0); \ + vst1q_f32(c + row0 * ldc + 4, c1); \ + vst1q_f32(c + row0 * ldc + 8, c2); \ + vst1q_f32(c + row1 * ldc, c3); \ + vst1q_f32(c + row1 * ldc + 4, c4); \ + vst1q_f32(c + row1 * ldc + 8, c5); + + FAST_GEMM_FINALE(0, 1); + FAST_GEMM_FINALE(2, 3); + FAST_GEMM_FINALE(4, 5); + FAST_GEMM_FINALE(6, 7); +#undef FAST_GEMM_FINALE +} + +#else // CV_NEON_AARCH64 +static void fast_gemm4x12_f32(int k, const char *a_, const char *b_, + char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; + + float32x4_t s00 = vdupq_n_f32(0.f), s01 = s00, s02 = s00, + s10 = s00, s11 = s00, s12 = s00, + s20 = s00, s21 = s00, s22 = s00, + s30 = s00, s31 = s00, s32 = s00; + + for(int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) + { + float32x4_t b0 = vld1q_f32(b), b1 = vld1q_f32(b + 4), b2 = vld1q_f32(b + 8); + + float32x4_t a0 = vld1q_dup_f32(a); + s00 = 
vmlaq_f32(a0, b0, s00); + s01 = vmlaq_f32(a0, b1, s01); + s02 = vmlaq_f32(a0, b2, s02); + + a0 = vld1q_dup_f32(a + 1); + s10 = vmlaq_f32(a0, b0, s10); + s11 = vmlaq_f32(a0, b1, s11); + s12 = vmlaq_f32(a0, b2, s12); + + a0 = vld1q_dup_f32(a + 2); + s20 = vmlaq_f32(a0, b0, s20); + s21 = vmlaq_f32(a0, b1, s21); + s22 = vmlaq_f32(a0, b2, s22); + + a0 = vld1q_dup_f32(a + 3); + s30 = vmlaq_f32(a0, b0, s30); + s31 = vmlaq_f32(a0, b1, s31); + s32 = vmlaq_f32(a0, b2, s32); + } + + float32x4_t c0, c1, c2, v_alpha = vdupq_n_f32(alpha); +#define FAST_GEMM_FINALE(row0) \ + c0 = vld1q_f32(c + row0 * ldc); \ + c1 = vld1q_f32(c + row0 * ldc + 4); \ + c2 = vld1q_f32(c + row0 * ldc + 8); \ + c0 = vmlaq_f32(c0, s##row0##0, v_alpha); \ + c1 = vmlaq_f32(c1, s##row0##1, v_alpha); \ + c2 = vmlaq_f32(c2, s##row0##2, v_alpha); \ + vst1q_f32(c + row0 * ldc, c0); \ + vst1q_f32(c + row0 * ldc + 4, c1); \ + vst1q_f32(c + row0 * ldc + 8, c2); + + FAST_GEMM_FINALE(0); + FAST_GEMM_FINALE(1); + FAST_GEMM_FINALE(2); + FAST_GEMM_FINALE(3); +#undef FAST_GEMM_FINALE +} + +#endif // micro kernel CV_NEON_AARCH64 + +static void fast_gemm_macro_kernel(int m, int n, int k, + const char *packed_A, const char *packed_B, + float alpha, char *c, int ldc0, int esz) { + int ldc0_esz = ldc0 * esz; + + double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough + for(int i = 0; i < m; i += FAST_GEMM_F32_MR) { + for(int j = 0; j < n; j += FAST_GEMM_F32_NR) { + char* cptr0 = &c[i * ldc0_esz + j * esz]; + char* cptr = cptr0; + int ldc = ldc0; + int mr = m - i < FAST_GEMM_F32_MR ? m - i : FAST_GEMM_F32_MR; + int nr = n - j < FAST_GEMM_F32_NR ? n - j : FAST_GEMM_F32_NR; + int nr_esz = nr * esz; + bool partial = (bool)((mr < FAST_GEMM_F32_MR) | (nr < FAST_GEMM_F32_NR)); + if (partial) { + memset(tempC, 0, sizeof(tempC)); + cptr = (char *)tempC; + ldc = FAST_GEMM_F32_NR; + for(int p = 0; p < mr; p++) + memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); + } +#if CV_NEON_AARCH64 + fast_gemm8x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#else + fast_gemm4x12_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); +#endif + + if (partial) { + for(int p = 0; p < mr; p++) + memcpy(cptr0 + p * ldc0_esz, cptr + p * (ldc * esz), nr_esz); + } + } + } +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *B, int ldb0, int ldb1, + float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = FAST_GEMM_STORAGE / ((MC + NC) * esz); + KC = KC > 8 ? KC : 8; + KC = KC < K ? KC : K; + + size_t buff_size = KC * (MC + NC) * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); + char* packed_b = packed_a + KC * MC * esz; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? 
N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? K - k0 : KC; +#if CV_NEON_AARCH64 + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#else + fast_gemm_pack4_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#endif + fast_gemm_pack12_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz); + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *packed_B, float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + + size_t buff_size = KC * MC * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); // TODO: use AutoBuffer + const char *packed_b_ = packed_B; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + packed_b_ = packed_B + j0 * K * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? 
K - k0 : KC; +#if CV_NEON_AARCH64 + fast_gemm_pack8_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#else + fast_gemm_pack4_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); +#endif + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz); + packed_b_ += _nc * kc; + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +#endif // CV_NEON, CV_NEON_AARCH64 + +// AVX and AVX2 (16 x 256-bit registers) +#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_AVX + +FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float) +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) + +int fastGemmPackBSize(int N, int K) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + + return static_cast((N + NC - 1) / NC) * NC * K; +} + +void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + + int n_tiles = (N + NC - 1) / NC; + for (int r = 0; r < n_tiles; ++r) { + int j0 = r * NC; + int nc = N - j0 < NC ? N - j0 : NC; + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for (int k = 0; k < K; k += KC) { + int kc = K - k < KC ? K - k : KC; + fast_gemm_pack8_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); + packed_B += _nc * kc; + } + } +} + +#if !CV_FMA3 // AVX workaround for FMA +#undef _mm256_fmadd_ps +#define _mm256_fmadd_ps(a, b, c) _mm256_add_ps(c, _mm256_mul_ps(a, b)) +#endif + +static void fast_gemm12x8_f32(int k, const char *a_, const char *b_, char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; + + __m256 s00 = _mm256_setzero_ps(), + s10 = _mm256_setzero_ps(), + s20 = _mm256_setzero_ps(), + s30 = _mm256_setzero_ps(), + s40 = _mm256_setzero_ps(), + s50 = _mm256_setzero_ps(), + s60 = _mm256_setzero_ps(), + s70 = _mm256_setzero_ps(), + s80 = _mm256_setzero_ps(), + s90 = _mm256_setzero_ps(), + s100 = _mm256_setzero_ps(), + s110 = _mm256_setzero_ps(); + for (int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) { + __m256 b0 = _mm256_loadu_ps(b); + + __m256 a0 = _mm256_set1_ps(*a); + s00 = _mm256_fmadd_ps(b0, a0, s00); + __m256 a1 = _mm256_set1_ps(*(a + 1)); + s10 = _mm256_fmadd_ps(b0, a1, s10); + __m256 a2 = _mm256_set1_ps(*(a + 2)); + s20 = _mm256_fmadd_ps(b0, a2, s20); + + a0 = _mm256_set1_ps(*(a + 3)); + s30 = _mm256_fmadd_ps(b0, a0, s30); + a1 = _mm256_set1_ps(*(a + 4)); + s40 = _mm256_fmadd_ps(b0, a1, s40); + a2 = _mm256_set1_ps(*(a + 5)); + s50 = _mm256_fmadd_ps(b0, a2, s50); + + a0 = _mm256_set1_ps(*(a + 6)); + s60 = _mm256_fmadd_ps(b0, a0, s60); + a1 = _mm256_set1_ps(*(a + 7)); + s70 = _mm256_fmadd_ps(b0, a1, s70); + a2 = _mm256_set1_ps(*(a + 8)); + s80 = _mm256_fmadd_ps(b0, a2, s80); + + a0 = _mm256_set1_ps(*(a + 9)); + s90 = _mm256_fmadd_ps(b0, a0, s90); + a1 = _mm256_set1_ps(*(a + 10)); + s100 = _mm256_fmadd_ps(b0, a1, s100); + a2 = _mm256_set1_ps(*(a + 11)); + s110 = _mm256_fmadd_ps(b0, a2, s110); + } + + __m256 c0, c1, c2, c3, v_alpha = 
_mm256_set1_ps(alpha); +#define FAST_GEMM_FINALE(row0, row1, row2, row3) \ + c0 = _mm256_loadu_ps(c + row0 * ldc); \ + c1 = _mm256_loadu_ps(c + row1 * ldc); \ + c2 = _mm256_loadu_ps(c + row2 * ldc); \ + c3 = _mm256_loadu_ps(c + row3 * ldc); \ + c0 = _mm256_fmadd_ps(s##row0##0, v_alpha, c0); \ + c1 = _mm256_fmadd_ps(s##row1##0, v_alpha, c1); \ + c2 = _mm256_fmadd_ps(s##row2##0, v_alpha, c2); \ + c3 = _mm256_fmadd_ps(s##row3##0, v_alpha, c3); \ + _mm256_storeu_ps(c + row0 * ldc, c0); \ + _mm256_storeu_ps(c + row1 * ldc, c1); \ + _mm256_storeu_ps(c + row2 * ldc, c2); \ + _mm256_storeu_ps(c + row3 * ldc, c3); \ + + FAST_GEMM_FINALE(0, 1, 2, 3); + FAST_GEMM_FINALE(4, 5, 6, 7); + FAST_GEMM_FINALE(8, 9, 10, 11); +#undef FAST_GEMM_FINALE +} + +static void fast_gemm_macro_kernel(int m, int n, int k, + const char *packed_A, const char *packed_B, + float alpha, char *c, int ldc0, int esz) { + int ldc0_esz = ldc0 * esz; + + double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough + for(int i = 0; i < m; i += FAST_GEMM_F32_MR) { + for(int j = 0; j < n; j += FAST_GEMM_F32_NR) { + char* cptr0 = &c[i * ldc0_esz + j * esz]; + char* cptr = cptr0; + int ldc = ldc0; + int mr = m - i < FAST_GEMM_F32_MR ? m - i : FAST_GEMM_F32_MR; + int nr = n - j < FAST_GEMM_F32_NR ? n - j : FAST_GEMM_F32_NR; + int nr_esz = nr * esz; + bool partial = (bool)((mr < FAST_GEMM_F32_MR) | (nr < FAST_GEMM_F32_NR)); + if (partial) { + memset(tempC, 0, sizeof(tempC)); + cptr = (char *)tempC; + ldc = FAST_GEMM_F32_NR; + for(int p = 0; p < mr; p++) + memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); + } + fast_gemm12x8_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); + + if (partial) { + for(int p = 0; p < mr; p++) + memcpy(cptr0 + p * ldc0_esz, cptr + p * (ldc * esz), nr_esz); + } + } + } +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *B, int ldb0, int ldb1, + float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = FAST_GEMM_STORAGE / ((MC + NC) * esz); + KC = KC > 8 ? KC : 8; + KC = KC < K ? KC : K; + + size_t buff_size = KC * (MC + NC) * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); + char* packed_b = packed_a + KC * MC * esz; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? 
K - k0 : KC; + fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); + fast_gemm_pack8_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz); + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *packed_B, float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + + size_t buff_size = KC * MC * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); // TODO: use AutoBuffer + const char *packed_b_ = packed_B; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + packed_b_ = packed_B + j0 * K * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? K - k0 : KC; + fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz); + packed_b_ += _nc * kc; + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +#endif // CV_AVX, CV_AVX2 + +// LASX (32 x 256-bit registers) +#if !defined(CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY) && CV_LASX + +FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float) +FAST_GEMM_IMPLEMENT_PACK(16, _f32, float, float) + +int fastGemmPackBSize(int N, int K) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + + return static_cast((N + NC - 1) / NC) * NC * K; +} + +void fastGemmPackBKernel(const char *B, char *packed_B, int N, int K, int ldb0, int ldb1, int esz) { + int GEMM_NC = FAST_GEMM_F32_NC, GEMM_NR = FAST_GEMM_F32_NR; + int NC = (((GEMM_NC < N ? 
GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + + int n_tiles = (N + NC - 1) / NC; + for (int r = 0; r < n_tiles; ++r) { + int j0 = r * NC; + int nc = N - j0 < NC ? N - j0 : NC; + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for (int k = 0; k < K; k += KC) { + int kc = K - k < KC ? K - k : KC; + fast_gemm_pack16_f32(nc, kc, B + (k * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_B); + packed_B += _nc * kc; + } + } +} + +static void fast_gemm12x16_f32(int k, const char *a_, const char *b_, char *c_, int ldc, float alpha) { + const float* a = (const float*)a_; + const float* b = (const float*)b_; + float* c = (float*)c_; + + __m256i dummy; + __m256 s00 = (__m256)__lasx_xvxor_v(dummy, dummy), s01 = s00, + s10 = s00, s11 = s00, + s20 = s00, s21 = s00, + s30 = s00, s31 = s00, + s40 = s00, s41 = s00, + s50 = s00, s51 = s00, + s60 = s00, s61 = s00, + s70 = s00, s71 = s00, + s80 = s00, s81 = s00, + s90 = s00, s91 = s00, + s100 = s00, s101 = s00, + s110 = s00, s111 = s00; + for (int p = 0; p < k; p++, a += FAST_GEMM_F32_MR, b += FAST_GEMM_F32_NR) { + __m256 b0 = (__m256)__lasx_xvld(b, 0), b1 = (__m256)__lasx_xvld(b + 8, 0); + + __m256 a0 = _v256_setall_ps(*a); + s00 = __lasx_xvfmadd_s(b0, a0, s00); + s01 = __lasx_xvfmadd_s(b1, a0, s01); + __m256 a1 = _v256_setall_ps(*(a + 1)); + s10 = __lasx_xvfmadd_s(b0, a1, s10); + s11 = __lasx_xvfmadd_s(b1, a1, s11); + __m256 a2 = _v256_setall_ps(*(a + 2)); + s20 = __lasx_xvfmadd_s(b0, a2, s20); + s21 = __lasx_xvfmadd_s(b1, a2, s21); + __m256 a3 = _v256_setall_ps(*(a + 3)); + s30 = __lasx_xvfmadd_s(b0, a3, s30); + s31 = __lasx_xvfmadd_s(b1, a3, s31); + + a0 = _v256_setall_ps(*(a + 4)); + s40 = __lasx_xvfmadd_s(b0, a0, s40); + s41 = __lasx_xvfmadd_s(b1, a0, s41); + a1 = _v256_setall_ps(*(a + 5)); + s50 = __lasx_xvfmadd_s(b0, a1, s50); + s51 = __lasx_xvfmadd_s(b1, a1, s51); + a2 = _v256_setall_ps(*(a + 6)); + s60 = __lasx_xvfmadd_s(b0, a2, s60); + s61 = __lasx_xvfmadd_s(b1, a2, s61); + a3 = _v256_setall_ps(*(a + 7)); + s70 = __lasx_xvfmadd_s(b0, a3, s70); + s71 = __lasx_xvfmadd_s(b1, a3, s71); + + a0 = _v256_setall_ps(*(a + 8)); + s80 = __lasx_xvfmadd_s(b0, a0, s80); + s81 = __lasx_xvfmadd_s(b1, a0, s81); + a1 = _v256_setall_ps(*(a + 9)); + s90 = __lasx_xvfmadd_s(b0, a1, s90); + s91 = __lasx_xvfmadd_s(b1, a1, s91); + a2 = _v256_setall_ps(*(a + 10)); + s100 = __lasx_xvfmadd_s(b0, a2, s100); + s101 = __lasx_xvfmadd_s(b1, a2, s101); + a3 = _v256_setall_ps(*(a + 11)); + s110 = __lasx_xvfmadd_s(b0, a3, s110); + s111 = __lasx_xvfmadd_s(b1, a3, s111); + } + + __m256 c0, c1, c2, c3, c4, c5, c6, c7, v_alpha = _v256_setall_ps(alpha); +#define FAST_GEMM_FINALE(row0, row1, row2, row3) \ + c0 = (__m256)__lasx_xvld(c + row0 * ldc, 0); \ + c1 = (__m256)__lasx_xvld(c + row0 * ldc, 8 * 4); \ + c2 = (__m256)__lasx_xvld(c + row1 * ldc, 0); \ + c3 = (__m256)__lasx_xvld(c + row1 * ldc, 8 * 4); \ + c4 = (__m256)__lasx_xvld(c + row2 * ldc, 0); \ + c5 = (__m256)__lasx_xvld(c + row2 * ldc, 8 * 4); \ + c6 = (__m256)__lasx_xvld(c + row3 * ldc, 0); \ + c7 = (__m256)__lasx_xvld(c + row3 * ldc, 8 * 4); \ + c0 = __lasx_xvfmadd_s(s##row0##0, v_alpha, c0); \ + c1 = __lasx_xvfmadd_s(s##row0##1, v_alpha, c1); \ + c2 = __lasx_xvfmadd_s(s##row1##0, v_alpha, c2); \ + c3 = __lasx_xvfmadd_s(s##row1##1, v_alpha, c3); \ + c4 = __lasx_xvfmadd_s(s##row2##0, v_alpha, c4); \ + c5 = __lasx_xvfmadd_s(s##row2##1, v_alpha, c5); \ + c6 = __lasx_xvfmadd_s(s##row3##0, v_alpha, c6); \ + c7 = __lasx_xvfmadd_s(s##row3##1, v_alpha, c7); \ 
+ __lasx_xvst(c0, c + row0 * ldc, 0); \ + __lasx_xvst(c1, c + row0 * ldc, 8 * 4); \ + __lasx_xvst(c2, c + row1 * ldc, 0); \ + __lasx_xvst(c3, c + row1 * ldc, 8 * 4); \ + __lasx_xvst(c4, c + row2 * ldc, 0); \ + __lasx_xvst(c5, c + row2 * ldc, 8 * 4); \ + __lasx_xvst(c6, c + row3 * ldc, 0); \ + __lasx_xvst(c7, c + row3 * ldc, 8 * 4); + + FAST_GEMM_FINALE(0, 1, 2, 3); + FAST_GEMM_FINALE(4, 5, 6, 7); + FAST_GEMM_FINALE(8, 9, 10, 11); +#undef FAST_GEMM_FINALE +} + +static void fast_gemm_macro_kernel(int m, int n, int k, + const char *packed_A, const char *packed_B, + float alpha, char *c, int ldc0, int esz) { + int ldc0_esz = ldc0 * esz; + + double tempC[FAST_GEMM_F32_MR * FAST_GEMM_F32_NR]; // make sure the buffer is big enough + for(int i = 0; i < m; i += FAST_GEMM_F32_MR) { + for(int j = 0; j < n; j += FAST_GEMM_F32_NR) { + char* cptr0 = &c[i * ldc0_esz + j * esz]; + char* cptr = cptr0; + int ldc = ldc0; + int mr = m - i < FAST_GEMM_F32_MR ? m - i : FAST_GEMM_F32_MR; + int nr = n - j < FAST_GEMM_F32_NR ? n - j : FAST_GEMM_F32_NR; + int nr_esz = nr * esz; + bool partial = (bool)((mr < FAST_GEMM_F32_MR) | (nr < FAST_GEMM_F32_NR)); + if (partial) { + memset(tempC, 0, sizeof(tempC)); + cptr = (char *)tempC; + ldc = FAST_GEMM_F32_NR; + for(int p = 0; p < mr; p++) + memcpy(cptr + p * (ldc * esz), cptr0 + p * ldc0_esz, nr_esz); + } + fast_gemm12x16_f32(k, packed_A + i * k * esz, packed_B + j * k * esz, cptr, ldc, alpha); + + if (partial) { + for(int p = 0; p < mr; p++) + memcpy(cptr0 + p * ldc0_esz, cptr + p * (ldc * esz), nr_esz); + } + } + } +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *B, int ldb0, int ldb1, + float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = FAST_GEMM_STORAGE / ((MC + NC) * esz); + KC = KC > 8 ? KC : 8; + KC = KC < K ? KC : K; + + size_t buff_size = KC * (MC + NC) * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); + char* packed_b = packed_a + KC * MC * esz; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? 
K - k0 : KC; + fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); + fast_gemm_pack16_f32(nc, kc, B + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b); + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz); + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +void fastGemmKernel(int M, int N, int K, + float alpha, const char *A, int lda0, int lda1, + const char *packed_B, float beta, char *C, int ldc, int esz) { + int GEMM_MC = FAST_GEMM_F32_MC, + GEMM_NC = FAST_GEMM_F32_NC, + GEMM_MR = FAST_GEMM_F32_MR, + GEMM_NR = FAST_GEMM_F32_NR; + + int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR; + int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR; + int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K); + + size_t buff_size = KC * MC * esz; + bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF; + int m_tiles = (M + MC - 1) / MC; + int n_tiles = (N + NC - 1) / NC; + int total_tiles = m_tiles * n_tiles; + + auto fn = [&](const Range &r) { + char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size)); // TODO: use AutoBuffer + const char *packed_b_ = packed_B; + int start = r.start; + int end = r.end; + + for (int tile_idx = start; tile_idx < end; tile_idx++) { + int i0 = (tile_idx / n_tiles) * MC; + int j0 = (tile_idx % n_tiles) * NC; + int mc = M - i0 < MC ? M - i0 : MC; + int nc = N - j0 < NC ? N - j0 : NC; + int ldc_block = ldc; + char* c_block = C + (i0 * ldc + j0) * esz; + packed_b_ = packed_B + j0 * K * esz; + + if (beta == 0.f) { + for(int i = 0; i < mc; i++) + memset(c_block + i * ldc_block * esz, 0, nc * esz); + } else if (beta != 1.f) { + for(int i = 0; i < mc; i++) { + float* c_i = (float*)c_block + i * ldc_block; + for(int j = 0; j < nc; j++) + c_i[j] *= beta; + } + } + + int _nc = static_cast((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz; + for(int k0 = 0; k0 < K; k0 += KC) + { + int kc = K - k0 < KC ? K - k0 : KC; + fast_gemm_pack12_f32(mc, kc, A + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a); + fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b_, alpha, c_block, ldc_block, esz); + packed_b_ += _nc * kc; + } + } + + if (!use_stackbuff) { + free(packed_a); + } + }; + + int total = total_tiles; + int cost_per_thread = static_cast((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR)); + double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0); + parallel_for_(Range(0, total), fn, nstripes); +} + +#endif // CV_LASX + +CV_CPU_OPTIMIZATION_NAMESPACE_END + +}} // cv::dnn diff --git a/modules/dnn/src/layers/gemm_layer.cpp b/modules/dnn/src/layers/gemm_layer.cpp new file mode 100644 index 0000000000..9aa3b1a238 --- /dev/null +++ b/modules/dnn/src/layers/gemm_layer.cpp @@ -0,0 +1,361 @@ +// This file is part of OpenCV project. +// It is subject to the license terms in the LICENSE file found in the top-level directory +// of this distribution and at http://opencv.org/license.html. 
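+
+// Gemm layer overview: the layer computes Y = alpha * op(A) * op(B) + beta * C,
+// where op() optionally transposes its argument according to transA / transB.
+// When B is constant it is pre-packed once in finalize() via fastGemmPackB(), so
+// forward() reuses the packed weights instead of re-packing B on every call; when
+// C is constant it is pre-broadcast to an (M, N) buffer already scaled by beta,
+// so forward() can copy it straight into the output.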
+ +#include "../precomp.hpp" +#include "layers_common.hpp" +// backends +#include "../op_cuda.hpp" +#ifdef HAVE_CUDA +// #include "../cuda4dnn/primitives/matmul.hpp" +#include "../cuda4dnn/primitives/inner_product.hpp" +using namespace cv::dnn::cuda4dnn; +#endif +#include "../op_cann.hpp" +#include "../ie_ngraph.hpp" +#include "../op_vkcom.hpp" + +#include +#include "cpu_kernels/fast_gemm.hpp" + +namespace cv { namespace dnn { + +class GemmLayerImpl CV_FINAL : public GemmLayer { +public: + GemmLayerImpl(const LayerParams& params) { + setParamsFrom(params); + + trans_a = params.get("transA", false); + trans_b = params.get("transB", false); + alpha = params.get("alpha", 1.0f); + beta = params.get("beta", 1.0f); + + const_B = params.get("constB", false); // true means blobs[0] is B + const_C = params.get("constC", false); // true means blobs.back() is C + have_bias = params.get("have_bias", false); // NOTE: have_bias being true does not mean bias is constant + + real_ndims_C = params.get("real_ndims_C", -1); + } + + virtual bool supportBackend(int backendId) CV_OVERRIDE { + return backendId == DNN_BACKEND_OPENCV || + (backendId == DNN_BACKEND_CUDA && const_B && !trans_a) || + backendId == DNN_BACKEND_CANN || + backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH || + (backendId == DNN_BACKEND_VKCOM && haveVulkan() && !have_bias && !trans_a); + } + + virtual bool getMemoryShapes(const std::vector &inputs, + const int requiredOutputs, + std::vector &outputs, + std::vector &internals) const CV_OVERRIDE { + int num_inputs = static_cast(inputs.size() + blobs.size()); + CV_CheckGE(num_inputs, 2, "DNN/Gemm: Gemm takes at least two inputs"); + CV_CheckLE(num_inputs, 3, "DNN/Gemm: Gemm takes at most three inputs"); + + // Check whether A and B are two dimensional + const auto shape_A = inputs[0]; + const auto shape_B = const_B ? shape(blobs[0]) : inputs[1]; + CV_CheckGE(shape_A.size(), static_cast(2), "DNN/Gemm: Tensor A must be n-dimensional (n >= 2)"); + CV_CheckEQ(shape_B.size(), static_cast(2), "DNN/Gemm: Tensor B must be two dimensional"); + + // Check legal matrix multiplication + size_t dims_A = shape_A.size(); + int ma = shape_A[dims_A - 2], na = shape_A[dims_A - 1]; + int mb = shape_B[0], nb = shape_B[1]; + int M = trans_a ? na : ma; + int N = trans_b ? mb : nb; + int K_a = trans_a ? ma : na; + int K_b = trans_b ? nb : mb; + CV_CheckEQ(K_a, K_b, "DNN/Gemm: Invalid dimension of dim K"); + + // Check whether C can be unidirectional broadcast to (M, N). Handle carefully with 1D Mat. + if (have_bias) { + const auto shape_C = const_C ? 
shape(blobs.back()) : inputs.back(); + + auto ndims_C = shape_C.size(); + CV_CheckLE(ndims_C, static_cast(2), "DNN/Gemm: C can only be 0d (scalar) / 1d / 2d tensor"); + + if (real_ndims_C == 1) { // (1,) or (N,) + CV_Check(shape_C[0], shape_C[0] == 1 || shape_C[0] == N, "DNN/Gemm: invalid dimension of C"); + } else if (real_ndims_C == 2) { // (1, 1) or (1, N) or (M, 1) or (M, N) + // printf("shape_C=[%d, %d]\n", shape_C[0], shape_C[1]); + CV_Check(shape_C[0], (shape_C[0] == 1 && shape_C[1] == 1) || + (shape_C[0] == 1 && shape_C[1] == N) || + (shape_C[0] == M && shape_C[1] == 1) || + (shape_C[0] == M && shape_C[1] == N), + "DNN/Gemm: C must be of shape (1, 1) or (1, N) or (M, 1) or (M, N)"); + if (shape_C[0] == 1) { + CV_Check(shape_C[1], shape_C[1] == 1 || shape_C[1] == N, "DNN/Gemm: invalid dimension of C"); + } else if (shape_C[0] == M) { + CV_Check(shape_C[1], shape_C[1] == 1 || shape_C[1] == N, "DNN/Gemm: invalid dimension of C"); + } else { + CV_Error(Error::StsBadSize, "DNN/Gemm: invalid dimension of C"); + } + } + } + + int batches = std::accumulate(shape_A.begin(), shape_A.end() - 2, 1, std::multiplies()); + MatShape shape_y{M * batches, N}; + outputs.assign(1, shape_y); + return false; + } + + // TODO: replace with cv::broadcast() once 1d mat is supported + // FIXME: fix if conditions if 1d mat is supported properly + void broadcastCWtihBeta(int M, int N, const Mat &C) { + if (beta != 0 && !C.empty()) { + broadcast_C.clear(); + broadcast_C.resize(M * N, 0.f); + + const float *ptr_c = C.ptr(); + const auto shape_C = shape(C); + if ((real_ndims_C == 0) || (real_ndims_C == 1 && shape_C[0] == 1) || + (real_ndims_C == 2 && shape_C[0] == 1 && shape_C[1] == 1)) { + // (), (1,), (1, 1) + float c = *ptr_c; + int total = M * N; + for (int i = 0; i < total; ++i) { + broadcast_C[i] = beta * c; + } + } else if ((real_ndims_C == 1 && shape_C[0] == N) || + (real_ndims_C == 2 && shape_C[0] == 1 && shape_C[1] == N)) { + // (N,), (1, N) + for (int i = 0; i < M; ++i) { + int step = i * N; + for (int j = 0; j < N; ++j) { + broadcast_C[step + j] = beta * ptr_c[j]; + } + } + } else if (real_ndims_C == 2 && shape_C[0] == M && shape_C[1] == 1) { + // (M, 1) + for (int i = 0; i < M; ++i) { + int step = i * N; + for (int j = 0; j < N; ++j) { + broadcast_C[step + j] = beta * ptr_c[i]; + } + } + } else { + // (M, N) + std::transform(ptr_c, ptr_c + M * N, broadcast_C.begin(), [this] (const float &c) { + return this->beta * c; }); + } + } + } + + virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE { + opt.init(); + + // pack B if it is const + if (const_B) { + fastGemmPackB(blobs[0], packed_B, trans_b, opt); + } + + // also pre-broadcast bias + if (const_C) { + const auto &C = blobs.back(); + + std::vector outputs; + outputs_arr.getMatVector(outputs); + const auto &Y = outputs[0]; + const auto shape_Y = shape(Y); + size_t dims_Y = shape_Y.size(); + int M = shape_Y[dims_Y - 2], N = shape_Y[dims_Y - 1]; + + // broadcast + broadcastCWtihBeta(M, N, C); + } + } + + // Y = A * B + C, note that C is unidirectionaly broadcastable to (A * B). 
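+    // forward() first copies the pre-broadcast beta * C into every batch of Y (or
+    // zero-initializes Y when there is no bias), then accumulates alpha * A * B on
+    // top of it: fastGemm() with the pre-packed weights when B is constant, and
+    // fastGemmBatched() otherwise. CV_16S inputs are routed to forward_fallback().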
+ void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE { + CV_TRACE_FUNCTION(); + CV_TRACE_ARG_VALUE(name, "name", name.c_str()); + + if (inputs_arr.depth() == CV_16S) + { + forward_fallback(inputs_arr, outputs_arr, internals_arr); + return; + } + + std::vector inputs, outputs; + inputs_arr.getMatVector(inputs); + outputs_arr.getMatVector(outputs); + + const auto &A = inputs[0]; + auto &Y = outputs[0]; + + const auto shape_A = shape(A), shape_Y = shape(Y); + size_t dims_A = shape_A.size(); + int ma = shape_A[dims_A - 2], na = shape_A[dims_A - 1]; + size_t dims_Y = shape_Y.size(); + int M = shape_Y[dims_Y - 2], N = shape_Y[dims_Y - 1]; + int K = trans_a ? ma : na; + int batches = std::accumulate(shape_A.begin(), shape_A.end() - 2, 1, std::multiplies()); + + // broadcast C and copy C to output + if (have_bias) { + if (!const_C) { + broadcastCWtihBeta(M, N, inputs.back()); + } + int step = M * N; + CV_CheckEQ(broadcast_C.size(), static_cast(step), "DNN/Gemm: C is not broadcast properly"); + float *ptr_y = Y.ptr(); + for (int i = 0; i < batches; i++) { + std::memcpy(ptr_y + i * step, broadcast_C.data(), step * sizeof(float)); + } + } else { // initialization + float *ptr_y = Y.ptr(); + size_t total = Y.total(); + std::memset(ptr_y, 0, total * sizeof(float)); + } + + if (const_B) { + CV_CheckGT(packed_B.size(), static_cast(0), "DNN/Gemm: constant B is not pre-packed"); + M *= batches; + fastGemm(trans_a, M, N, K, alpha, A.ptr(), na, packed_B.data(), 1.f, Y.ptr(), N, opt); + } else { + fastGemmBatched(trans_a, trans_b, alpha, A, inputs[1], 1.f, Y, opt); + } + } + +#ifdef HAVE_CUDA + // Y = A * B + C. B should be guaranteed as two dimensional. + Ptr initCUDA(void *context_, + const std::vector>& inputs, + const std::vector>& outputs) CV_OVERRIDE { + CV_CheckFalse(trans_a, "DNN/Gemm/Cuda: does not support transA"); + CV_CheckTrue(const_B, "DNN/Gemm/Cuda: input B (weight) is required to be constant"); + auto context = reinterpret_cast(context_); + auto wrapper_A = inputs[0].dynamicCast(); + auto B = blobs[0]; + auto C = have_bias && const_C ? blobs[1] : Mat(); // in most cases C is constant + + if (!trans_b) + cv::transpose(B, B); + auto flatten_start_axis = normalize_axis(1, wrapper_A->getRank()); + return make_cuda_node(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, B, C); + } +#endif // HAVE_CUDA + +#ifdef HAVE_CANN + // Y = A * B + C. 
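+    // initCann() maps the layer onto the CANN GEMM operator: transA / transB become
+    // the operator's transpose attributes, x2 is fed either from the constant blob
+    // or from the second graph input, and the bias input always receives a tensor
+    // (the constant C when available, a zero scalar otherwise).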
+ virtual Ptr initCann(const std::vector > &inputs, + const std::vector > &outputs, + const std::vector >& nodes) CV_OVERRIDE { + auto x1 = inputs[0].dynamicCast(); + auto desc_x1 = x1->getTensorDesc(); + auto op_x1 = nodes[0].dynamicCast()->getOp(); + + auto op = std::make_shared(name); + + // set attributes + op->set_attr_transpose_x1(trans_a); + op->set_attr_transpose_x2(trans_b); + + // set inputs + // set inputs : x1 + op->set_input_x1_by_name(*op_x1, x1->name.c_str()); + op->update_input_desc_x1(*desc_x1); + // set inputs : x2 + if (const_B) { + auto B = blobs[0]; + auto op_const_B = std::make_shared(B.data, B.type(), shape(B), cv::format("%s_w", name.c_str())); + op->set_input_x2_by_name(*(op_const_B->getOp()), "y"); + op->update_input_desc_x2(*(op_const_B->getTensorDesc())); + } else { + CV_CheckGE(inputs.size(), static_cast(2), "DNN/Gemm/CANN: input B is required since it is not constant"); + CV_CheckGE(nodes.size(), static_cast(2), "DNN/Gemm/CANN: input B is required since it is not constant"); + auto op_x2 = nodes[1].dynamicCast()->getOp(); + auto desc_x2 = inputs[1].dynamicCast()->getTensorDesc(); + op->set_input_x2_by_name(*op_x2, "y"); + op->update_input_desc_x2(*desc_x2); + } + // set inputs : bias + auto mat_C = have_bias && const_C ? blobs.back() : Mat::zeros(1, 1, CV_32F); + auto op_const_C = std::make_shared(mat_C.data, mat_C.type(), shape(mat_C), cv::format("%s_b", name.c_str())); + op->set_input_bias(*(op_const_C->getOp())); + op->update_input_desc_bias(*(op_const_C->getTensorDesc())); + + // set outputs + op->update_output_desc_y(*output_desc); + return Ptr(new CannBackendNode(op)); + } +#endif // HAVE_CANN + +#ifdef HAVE_DNN_NGRAPH + virtual Ptr initNgraph(const std::vector >& inputs, + const std::vector >& nodes) CV_OVERRIDE + { + auto& ieInpNode = nodes[0].dynamicCast()->node; + std::shared_ptr matmul; + int axis = -2; + + if (nodes.size() == 2) + { + auto& inp2 = nodes[1].dynamicCast()->node; + matmul = std::make_shared(ieInpNode, inp2, transA, transB); + } + else + { + std::vector shape(1 + normalize_axis(axis, ieInpNode->get_shape().size()), 0); + shape[shape.size() - 1] = -1; + auto inp = std::make_shared( + ieInpNode, + std::make_shared(ngraph::element::i32, ngraph::Shape{shape.size()}, shape.data()), + true + ); + + std::vector weight_shape{(size_t)blobs[0].size[0], (size_t)blobs[0].size[1]}; + auto ieWeights = std::make_shared(ngraph::element::f32, weight_shape, blobs[0].data); + matmul = std::make_shared(inp, ieWeights, transA, transB); + } + + if (have_bias && const_C) { + auto bias_node = std::make_shared(ngraph::element::f32, + ngraph::Shape{(size_t)blobs.back().size[1]}, blobs.back().data); + matmul = std::make_shared(matmul, bias_node, ngraph::op::AutoBroadcastType::NUMPY); + } + return Ptr(new InfEngineNgraphNode(matmul)); + } +#endif // HAVE_DNN_NGRAPH + +#ifdef HAVE_VULKAN + // Y = A * B + C. Currently support 2d matrix multiplication without bias. 
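+    // initVkCom() only handles the bias-free 2D case; for any other configuration
+    // it returns an empty node so that the layer is executed by another backend.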
+ virtual Ptr initVkCom(const std::vector > &inputs, + std::vector > &outputs) CV_OVERRIDE + { + // does not support with bias; only 2d matmul + auto wrapper_Y = outputs[0].dynamicCast(); + auto shape_Y = shape(*(wrapper_Y->getMat())); + if (have_bias || shape_Y.size() > static_cast(2)) { + return Ptr(); + } + + std::vector vkBlobs; + if (const_B) { + vkBlobs.push_back(blobs[0]); + } + + auto wrapper_A = inputs[0].dynamicCast(); + auto shape_A = shape(*wrapper_A->getMat()); + Ptr op = (new vkcom::OpMatMul(vkBlobs, shape_A[0], shape_A[1], shape_Y[1])); + return Ptr(new VkComBackendNode(inputs, op, outputs)); + } +#endif + +private: + bool const_B; + bool const_C; + bool have_bias; + std::vector packed_B; + std::vector broadcast_C; + int real_ndims_C; + FastGemmOpt opt; +}; + +Ptr GemmLayer::create(const LayerParams& params) { + return makePtr(params); +} + +}} // namespace cv::dnn diff --git a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 24e8b3f913..671bf36957 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -1947,73 +1947,45 @@ void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const openc addLayer(layerParams, node_proto); } -// A * B + C = Y, we require that the dimension of A is [m, k], and the dimension of B is [n, k]. -// And the dim of output Y is [m, n] -void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) +void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) { - CV_Assert(node_proto.input_size() >= 2); - layerParams.type = "InnerProduct"; - int transA = layerParams.get("transA", 0); - layerParams.set("transA", transA == 1); + auto node_proto = node_proto_; + layerParams.type = "Gemm"; + CV_CheckGE(node_proto.input_size(), 2, "DNN/ONNXImporter: Gemm requires at least two inputs"); + CV_CheckLE(node_proto.input_size(), 3, "DNN/ONNXImporter: Gemm have at most three inputs."); - if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) - { - Mat inputBuf = getBlob(node_proto, 0); - - LayerParams constParams; - constParams.name = node_proto.input(0); - constParams.type = "Const"; - constParams.blobs.push_back(inputBuf); - - opencv_onnx::NodeProto proto; - proto.add_output(constParams.name); - addLayer(constParams, proto); - } - - int transB = layerParams.get("transB", 0); - int secondInpDims; - if (constBlobs.find(node_proto.input(1)) != constBlobs.end()) - { - Mat weights = getBlob(node_proto, 1); - secondInpDims = weights.dims; - - if (transA == 0) // optimized barnch, for now, we can only optimize the Gemm when transA = 0. - { - if (transB == 0) - { - transpose(weights, weights); - } - layerParams.set("transB", false); - layerParams.blobs.push_back(weights); - layerParams.set("num_output", layerParams.blobs[0].size[0]); + for (int i = 0; i < node_proto.input_size(); ++i) { + if (i == 2) { + layerParams.set("have_bias", true); + } + if (constBlobs.find(node_proto.input(i)) == constBlobs.end()) { + continue; } - else // no optimized branch, TODO! optimize when the transA==1. 
- { - LayerParams constParams; - constParams.name = node_proto.input(1); - constParams.type = "Const"; - constParams.blobs.push_back(weights); - opencv_onnx::NodeProto proto; - proto.add_output(constParams.name); - addLayer(constParams, proto); - layerParams.set("transB", transB == 1); + if (i == 2 && constBlobsExtraInfo.find(node_proto.input(2)) != constBlobsExtraInfo.end()) { + layerParams.set("real_ndims_C", getBlobExtraInfo(node_proto, 2).real_ndims); + } + + Mat blob = getBlob(node_proto, i); + + if (i == 0) { // A, always as inputs without prepacking + LayerParams const_A_params; + const_A_params.name = layerParams.name + "/const_A"; + const_A_params.type = "Const"; + const_A_params.blobs.push_back(blob); + + opencv_onnx::NodeProto const_node_proto; + const_node_proto.add_output(const_A_params.name); + addLayer(const_A_params, const_node_proto); + node_proto.set_input(0, const_A_params.name); + } else { // B or C + std::string const_params_name = i == 1 ? "B" : "C"; + + layerParams.blobs.push_back(blob); + layerParams.set(cv::format("const%s", const_params_name.c_str()), true); } } - else - { - layerParams.set("transB", transB == 1); - secondInpDims = outShapes[node_proto.input(1)].size(); - } - if (node_proto.input_size() == 3) - { - Mat bias = getBlob(node_proto, 2); - layerParams.blobs.push_back(bias); - } - - layerParams.set("bias_term", node_proto.input_size() == 3); - layerParams.set("is_matmul", secondInpDims > 2); addLayer(layerParams, node_proto); } diff --git a/modules/dnn/test/test_int8_layers.cpp b/modules/dnn/test/test_int8_layers.cpp index caba112516..fd42dfce48 100644 --- a/modules/dnn/test/test_int8_layers.cpp +++ b/modules/dnn/test/test_int8_layers.cpp @@ -366,7 +366,7 @@ TEST_P(Test_Int8_layers, InnerProduct) testLayer("matmul_layout", "TensorFlow", 0.035, 0.06); testLayer("tf2_dense", "TensorFlow", 0, 0); testLayer("matmul_add", "ONNX", 0.041, 0.082); - testLayer("linear", "ONNX", 0.0018, 0.0029); + testLayer("linear", "ONNX", 0.0027, 0.0046); if (backend == DNN_BACKEND_TIMVX) testLayer("constant", "ONNX", 0.00048, 0.0013); @@ -384,7 +384,7 @@ TEST_P(Test_Int8_layers, InnerProduct) testLayer("matmul_layout", "TensorFlow", 0.035, 0.095, 1, 1, false, true, false, false); testLayer("tf2_dense", "TensorFlow", 0, 0, 1, 1, false, true, false, false); testLayer("matmul_add", "ONNX", 0.041, 0.082, 1, 1, false, true, false, false); - testLayer("linear", "ONNX", 0.0022, 0.004, 1, 1, false, true, false, false); + testLayer("linear", "ONNX", 0.0027, 0.005, 1, 1, false, true, false, false); testLayer("constant", "ONNX", 0.00038, 0.0012, 1, 1, false, true, false, false); testLayer("lin_with_constant", "ONNX", 0.0011, 0.0016, 1, 1, false, true, false, false); } @@ -837,7 +837,7 @@ TEST_P(Test_Int8_nets, RCNN_ILSVRC13) if (target == DNN_TARGET_OPENCL && !ocl::Device::getDefault().isIntel()) applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL); - float l1 = 0.02, lInf = 0.042; + float l1 = 0.02, lInf = 0.047; testONNXNet("rcnn_ilsvrc13", l1, lInf); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index d695b1c202..3df0cf1924 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -2597,6 +2597,40 @@ TEST_P(Test_ONNX_layers, where_node) testONNXModels("where_layer"); } +TEST_P(Test_ONNX_layers, Conformance_Gemm_all_attributes) { + testONNXModels("test_gemm_all_attributes", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_alpha) { + testONNXModels("test_gemm_alpha", pb, 0, 0, 
false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_beta) { + testONNXModels("test_gemm_beta", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_default_matrix_bias) { + testONNXModels("test_gemm_default_matrix_bias", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_default_no_bias) { + testONNXModels("test_gemm_default_no_bias", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_default_scalar_bias) { + testONNXModels("test_gemm_default_scalar_bias", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_default_single_elem_vector_bias) { + testONNXModels("test_gemm_default_single_elem_vector_bias", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_default_vector_bias) { + testONNXModels("test_gemm_default_vector_bias", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_default_zero_bias) { + testONNXModels("test_gemm_default_zero_bias", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_transposeA) { + testONNXModels("test_gemm_transposeA", pb, 0, 0, false, true, 2); +} +TEST_P(Test_ONNX_layers, Conformance_Gemm_transposeB) { + testONNXModels("test_gemm_transposeB", pb, 0, 0, false, true, 2); +} + INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets()); }} // namespace
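For reference, a minimal usage sketch (not part of the patch) of exercising the new layer through the public cv::dnn API. It assumes the layer is registered under the "Gemm" type that the ONNX importer above emits, that addLayerToPrev() on an empty Net wires the layer to the network input, and that the shapes and values are purely illustrative:

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main() {
        cv::dnn::LayerParams lp;
        lp.set("transA", false);
        lp.set("transB", true);                           // B is stored as (N, K)
        lp.set("alpha", 1.0f);
        lp.set("beta", 1.0f);
        lp.set("constB", true);                           // B is taken from blobs[0] and pre-packed
        lp.blobs.push_back(cv::Mat::ones(8, 16, CV_32F)); // B: N = 8, K = 16

        cv::dnn::Net net;
        net.addLayerToPrev("gemm_demo", "Gemm", lp);

        cv::Mat A = cv::Mat::ones(4, 16, CV_32F);         // A: M = 4, K = 16
        net.setInput(A);
        cv::Mat Y = net.forward();                        // Y: 4 x 8, every entry equals K = 16
        std::cout << Y.rows << " x " << Y.cols << std::endl;
        return 0;
    }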