Merge pull request #24694 from fengyuentau:matmul_refactor

dnn: refactor ONNX MatMul with fastGemm #24694

Done:
- [x] add backends
    - [x] CUDA
    - [x] OpenVINO
    - [x] CANN
    - [x] OpenCL
    - [x] Vulkan
- [x] add perf tests
- [x] const B case

### Benchmark

Tests were run on an Apple M1. All timings are in milliseconds (ms); a minimal usage sketch of the measured setup follows the table.

| Configuration | MatMul (Prepacked) | MatMul | InnerProduct |
| - | - | - | - |
| A=[12, 197, 197], B=[12, 197, 64], trans_a=0, trans_b=0 | **0.39** | 0.41 | 1.33 |
| A=[12, 197, 64], B=[12, 64, 197], trans_a=0, trans_b=0  | **0.42** | 0.42 | 1.17 |
| A=[12, 50, 64], B=[12, 64, 50], trans_a=0, trans_b=0    | **0.13** | 0.15 | 0.33 |
| A=[12, 50, 50], B=[12, 50, 64], trans_a=0, trans_b=0    | **0.11** | 0.13 | 0.22 |
| A=[16, 197, 197], B=[16, 197, 64], trans_a=0, trans_b=0 | **0.46** | 0.54 | 1.46 |
| A=[16, 197, 64], B=[16, 64, 197], trans_a=0, trans_b=0  | **0.46** | 0.95 | 1.74 |
| A=[16, 50, 64], B=[16, 64, 50], trans_a=0, trans_b=0    | **0.18** | 0.32 | 0.43 |
| A=[16, 50, 50], B=[16, 50, 64], trans_a=0, trans_b=0    | **0.15** | 0.25 | 0.25 |
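
For reference, each configuration above drives a single-layer net through the dnn API, the same way the added perf test does. A minimal sketch (shapes taken from the first row; passing B as a constant blob selects the prepacked path):

```cpp
#include <opencv2/dnn.hpp>

using namespace cv;
using namespace cv::dnn;

int main() {
    // A = [12, 197, 197], B = [12, 197, 64] -> Y = [12, 197, 64]
    std::vector<int> a_shape{12, 197, 197}, b_shape{12, 197, 64};
    Mat A(a_shape, CV_32F), B(b_shape, CV_32F);
    randu(A, -1.0f, 1.0f);
    randu(B, -1.0f, 1.0f);

    LayerParams lp;
    lp.type = "MatMul";
    lp.name = "testLayer";
    lp.set("transA", 0);
    lp.set("transB", 0);
    lp.blobs.push_back(B);  // constant B: pre-packed by fastGemmPackB in finalize()

    Net net;
    net.addLayerToPrev(lp.name, lp.type, lp);
    net.setPreferableBackend(DNN_BACKEND_OPENCV);

    std::vector<std::string> input_names{"A"};
    net.setInputsNames(input_names);
    net.setInput(A, input_names[0]);

    Mat Y = net.forward();
    return 0;
}
```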

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There are accuracy tests, performance tests and test data in the opencv_extra repository, if applicable
      Patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
Yuantao Feng 2023-12-20 00:36:41 +08:00 committed by GitHub
parent 465e601e10
commit fa5ed62a66
13 changed files with 1340 additions and 107 deletions

View File

@ -1160,6 +1160,11 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<GemmLayer> create(const LayerParams& params);
};
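/** @brief Computes Y = matmul(A, B) with numpy-style broadcasting of the batch dimensions.
 * B is either a second variable input or a constant blob; a constant B is pre-packed for fastGemm.
 */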
class CV_EXPORTS MatMulLayer : public Layer {
public:
static Ptr<MatMulLayer> create(const LayerParams &params);
};
class CV_EXPORTS ExpandLayer : public Layer
{
public:

View File

@ -5,6 +5,8 @@
#include "perf_precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include <numeric>
namespace opencv_test {
struct GemmParam_t {
@ -71,6 +73,18 @@ static const GemmParam_t test_gemm_configs[] = {
*/
};
static const GemmParam_t test_matmul_configs[] = {
// vision transformer cases
{ {12, 197, 197}, {12, 197, 64} },
{ {12, 197, 64 }, {12, 64, 197} },
{ {12, 50, 64}, {12, 64, 50} },
{ {12, 50, 50}, {12, 50, 64} },
{ {16, 197, 197}, {16, 197, 64} },
{ {16, 197, 64 }, {16, 64, 197} },
{ {16, 50, 64}, {16, 64, 50} },
{ {16, 50, 50}, {16, 50, 64} },
};
struct GemmParamId
{
enum {
@ -88,6 +102,21 @@ struct GemmParamId
}
};
struct MatMulParamId {
enum {
MATMUL_0 = 0,
MATMUL_LAST = sizeof(test_matmul_configs) / sizeof(test_matmul_configs[0])
};
int val_;
MatMulParamId(int val = 0) : val_(val) {}
operator int() const { return val_; }
static ::testing::internal::ParamGenerator<MatMulParamId> all() {
enum { NUM = (int)MATMUL_LAST };
MatMulParamId v_[NUM]; for (int i = 0; i < NUM; i++) { v_[i] = MatMulParamId(i); }
return ::testing::ValuesIn(v_, v_ + NUM);
}
};
static inline void PrintTo(const GemmParamId& v, std::ostream* os)
{
CV_Assert((int)v >= 0); CV_Assert((int)v < GemmParamId::GEMM_LAST);
@ -138,7 +167,7 @@ PERF_TEST_P_(Gemm, gemm)
Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
randu(A, -1.0f, 1.0f);
Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
randu(A, -1.0f, 1.0f);
randu(B, -1.0f, 1.0f);
LayerParams lp;
lp.type = "Gemm";
@ -197,7 +226,7 @@ PERF_TEST_P_(Gemm, innerproduct)
Mat A(static_cast<int>(a_shape.size()), a_shape.data(), CV_32F);
randu(A, -1.0f, 1.0f);
Mat B(static_cast<int>(b_shape.size()), b_shape.data(), CV_32F);
randu(A, -1.0f, 1.0f);
randu(B, -1.0f, 1.0f);
LayerParams lp;
lp.type = "InnerProduct";
@ -241,9 +270,146 @@ PERF_TEST_P_(Gemm, innerproduct)
SANITY_CHECK_NOTHING();
}
static inline void PrintTo(const MatMulParamId& v, std::ostream* os)
{
CV_Assert((int)v >= 0); CV_Assert((int)v < MatMulParamId::MATMUL_LAST);
const GemmParam_t& p = test_matmul_configs[(int)v];
auto print_shape = [os](const std::vector<int>& shape, const std::string tag) {
if (shape.empty()) {
return ;
}
*os << tag << "=[";
for (size_t i = 0; i < shape.size(); ++i) {
if (i == shape.size() - 1) {
*os << shape[i] << "]";
break;
}
*os << shape[i] << ", ";
}
};
print_shape(p.a_shape, "A");
print_shape(p.b_shape, ", B");
print_shape(p.c_shape, ", C");
*os << ", trans_a=" << p.trans_a << ", trans_b=" << p.trans_b;
}
using MatMulTestParam_t = tuple<MatMulParamId, tuple<Backend, Target>>;
using MatMul = TestBaseWithParam<MatMulTestParam_t>;
PERF_TEST_P_(MatMul, matmul)
{
int test_id = (int)get<0>(GetParam());
ASSERT_GE(test_id, 0); ASSERT_LT(test_id, MatMulParamId::MATMUL_LAST);
const GemmParam_t& params = test_matmul_configs[test_id];
auto a_shape = params.a_shape;
auto b_shape = params.b_shape;
auto trans_a = params.trans_a;
auto trans_b = params.trans_b;
float alpha = 1.f;
float beta = 1.f;
Backend backend_id = get<0>(get<1>(GetParam()));
Target target_id = get<1>(get<1>(GetParam()));
Mat A(a_shape, CV_32F);
randu(A, -1.0f, 1.0f);
Mat B(b_shape, CV_32F);
randu(B, -1.0f, 1.0f);
LayerParams lp;
lp.type = "MatMul";
lp.name = "testLayer";
lp.set("transA", trans_a);
lp.set("transB", trans_b);
lp.set("alpha", alpha);
lp.set("beta", beta);
lp.blobs.push_back(B);
Net net;
net.addLayerToPrev(lp.name, lp.type, lp);
net.setPreferableBackend(backend_id);
net.setPreferableTarget(target_id);
// warmup
{
std::vector<std::string> input_names{"A"};
net.setInputsNames(input_names);
net.setInput(A, input_names[0]);
Mat out = net.forward();
}
TEST_CYCLE()
{
Mat res = net.forward();
}
SANITY_CHECK_NOTHING();
}
PERF_TEST_P_(MatMul, innerproduct)
{
int test_id = (int)get<0>(GetParam());
ASSERT_GE(test_id, 0); ASSERT_LT(test_id, MatMulParamId::MATMUL_LAST);
const GemmParam_t& params = test_matmul_configs[test_id];
auto a_shape = params.a_shape;
auto b_shape = params.b_shape;
Backend backend_id = get<0>(get<1>(GetParam()));
Target target_id = get<1>(get<1>(GetParam()));
Mat A(a_shape, CV_32F);
randu(A, -1.0f, 1.0f);
Mat B(b_shape, CV_32F);
randu(B, -1.0f, 1.0f);
LayerParams lp;
lp.type = "InnerProduct";
lp.name = "testLayer";
lp.set("axis", (int)(a_shape.size() - 1));
lp.set("bias_term", false);
// pre-transpose
std::vector<int> order(b_shape.size());
std::iota(order.begin(), order.end(), 0);
std::swap(order.back(), order[b_shape.size() - 2]);
Mat B_transposed;
transposeND(B, order, B_transposed);
lp.blobs.push_back(B_transposed);
lp.set("num_output", int(B_transposed.total(0, b_shape.size() - 1)));
lp.set("is_matmul", true);
Net net;
net.addLayerToPrev(lp.name, lp.type, lp);
net.setPreferableBackend(backend_id);
net.setPreferableTarget(target_id);
// warmup
{
std::vector<std::string> input_names{"A"};
net.setInputsNames(input_names);
net.setInput(A, input_names[0]);
Mat out = net.forward();
}
TEST_CYCLE()
{
Mat res = net.forward();
}
SANITY_CHECK_NOTHING();
}
INSTANTIATE_TEST_CASE_P(/**/, Gemm, Combine(
GemmParamId::all(),
dnnBackendsAndTargets(false, false) // defined in ../test/test_common.hpp
));
INSTANTIATE_TEST_CASE_P(/**/, MatMul, Combine(
MatMulParamId::all(),
dnnBackendsAndTargets(false, false) // defined in ../test/test_common.hpp
));
} // namespace

View File

@ -8,6 +8,7 @@
#include "error.hpp"
#include "stream.hpp"
#include "pointer.hpp"
#include "memory.hpp"
#include <opencv2/core.hpp>
@ -363,6 +364,145 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
);
}
/** @brief Strided batched GEMM for column-major matrices
*
* \f$ C_i = \alpha A_i B_i + \beta C_i \f$ for a stack of matrices A, B and C indexed by i
*
* @tparam T matrix element type (must be `half` or `float`)
*
* @param handle valid cuBLAS Handle
* @param trans_a use transposed matrix of A_i for computation
* @param trans_b use transposed matrix of B_i for computation
* @param M number of rows in C
* @param N number of columns in C
* @param K common dimension of A (or trans A) and B (or trans B)
* @param alpha scale factor for A B
* @param[in] A pointer to stack of column-major matrices A in device memory
* @param lda leading dimension of matrix A
* @param A_offsets offsets to get A slices
* @param[in] B pointer to stack of column-major matrices B in device memory
* @param ldb leading dimension of matrix B
* @param B_offsets offsets to get B slices
* @param beta scale factor for C
* @param[in,out] C pointer to stack of column-major matrices C in device memory
* @param ldc leading dimension of matrix C
* @param C_offsets offsets to get C slices
* @param batchCount number of matrices in the batch
*
* Exception Guarantee: Basic
*/
template <class T>
void gemmBatched(const Handle &handle,
bool trans_a, bool trans_b,
std::size_t M, std::size_t N, std::size_t K,
T alpha,
const DevicePtr<const T> A, std::size_t lda, std::vector<std::size_t> A_offsets,
const DevicePtr<const T> B, std::size_t ldb, std::vector<std::size_t> B_offsets,
T beta,
const DevicePtr<T> C, std::size_t ldc, std::vector<std::size_t> C_offsets,
std::size_t batchCount);
template <> inline
void gemmBatched<half>(const Handle &handle,
bool trans_a, bool trans_b,
std::size_t M, std::size_t N, std::size_t K,
half alpha,
const DevicePtr<const half> A, std::size_t lda, std::vector<std::size_t> A_offsets,
const DevicePtr<const half> B, std::size_t ldb, std::vector<std::size_t> B_offsets,
half beta,
const DevicePtr<half> C, std::size_t ldc, std::vector<std::size_t> C_offsets,
std::size_t batchCount) {
CV_Assert(handle);
const auto opa = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N,
opb = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto iM = static_cast<int>(M),
iN = static_cast<int>(N),
iK = static_cast<int>(K),
ilda = static_cast<int>(lda),
ildb = static_cast<int>(ldb),
ildc = static_cast<int>(ldc);
const auto batch_count = static_cast<int>(batchCount);
AutoBuffer<half> buffer(3 * batch_count);
auto A_slices = (half**)(buffer.data());
auto B_slices = A_slices + batch_count;
auto C_slices = B_slices + batch_count;
// collect A, B and C slices
for (int i = 0; i < batch_count; i++) {
A_slices[i] = (half*)(A.get()) + A_offsets[i];
B_slices[i] = (half*)(B.get()) + B_offsets[i];
C_slices[i] = (half*)(C.get()) + C_offsets[i];
}
const half **dev_A_slices = 0, **dev_B_slices = 0;
half **dev_C_slices = 0;
cudaMalloc((void**)&dev_A_slices, batch_count * sizeof(half*));
cudaMalloc((void**)&dev_B_slices, batch_count * sizeof(half*));
cudaMalloc((void**)&dev_C_slices, batch_count * sizeof(half*));
cudaMemcpy(dev_A_slices, A_slices, batch_count * sizeof(half*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_B_slices, B_slices, batch_count * sizeof(half*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_C_slices, C_slices, batch_count * sizeof(half*), cudaMemcpyHostToDevice);
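// cuBLAS is column-major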
CUDA4DNN_CHECK_CUBLAS(cublasHgemmBatched(handle.get(), opa, opb, iM, iN, iK, &alpha, dev_A_slices, ilda, dev_B_slices, ildb, &beta, dev_C_slices, ildc, batch_count));
cudaFree(dev_A_slices);
cudaFree(dev_B_slices);
cudaFree(dev_C_slices);
}
template <> inline
void gemmBatched<float>(const Handle &handle,
bool trans_a, bool trans_b,
std::size_t M, std::size_t N, std::size_t K,
float alpha,
const DevicePtr<const float> A, std::size_t lda, std::vector<std::size_t> A_offsets,
const DevicePtr<const float> B, std::size_t ldb, std::vector<std::size_t> B_offsets,
float beta,
const DevicePtr<float> C, std::size_t ldc, std::vector<std::size_t> C_offsets,
std::size_t batchCount) {
CV_Assert(handle);
const auto opa = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N,
opb = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto iM = static_cast<int>(M),
iN = static_cast<int>(N),
iK = static_cast<int>(K),
ilda = static_cast<int>(lda),
ildb = static_cast<int>(ldb),
ildc = static_cast<int>(ldc);
const auto batch_count = static_cast<int>(batchCount);
AutoBuffer<float> buffer(3 * batch_count);
auto A_slices = (float**)(buffer.data());
auto B_slices = A_slices + batch_count;
auto C_slices = B_slices + batch_count;
// collect A, B and C slices
for (int i = 0; i < batch_count; i++) {
A_slices[i] = (float*)(A.get()) + A_offsets[i];
B_slices[i] = (float*)(B.get()) + B_offsets[i];
C_slices[i] = (float*)(C.get()) + C_offsets[i];
}
const float **dev_A_slices = 0, **dev_B_slices = 0;
float **dev_C_slices = 0;
cudaMalloc((void**)&dev_A_slices, batch_count * sizeof(float*));
cudaMalloc((void**)&dev_B_slices, batch_count * sizeof(float*));
cudaMalloc((void**)&dev_C_slices, batch_count * sizeof(float*));
cudaMemcpy(dev_A_slices, A_slices, batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_B_slices, B_slices, batch_count * sizeof(float*), cudaMemcpyHostToDevice);
cudaMemcpy(dev_C_slices, C_slices, batch_count * sizeof(float*), cudaMemcpyHostToDevice);
// cuBLAS is column-major
CUDA4DNN_CHECK_CUBLAS(cublasSgemmBatched(handle.get(), opa, opb, iM, iN, iK, &alpha, dev_A_slices, ilda, dev_B_slices, ildb, &beta, dev_C_slices, ildc, batch_count));
cudaFree(dev_A_slices);
cudaFree(dev_B_slices);
cudaFree(dev_C_slices);
}
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */

View File

@ -152,6 +152,31 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
batch_size);
}
/** @brief performs generalized matrix-multiplication for a strided batch of matrices
*
* Pre-conditions:
* - A, B and C must be rank three tensors with dimensions (batch, rows, cols)
* - the last two axes of \p A and \p B must meet the mathematical requirements for matrix multiplication
* - \p C must be large enough to hold the result and the matrices must not overlap in memory
*
* Exception Guarantee: Basic
*/
template <class T> inline
void gemmBatched(const cublas::Handle& handle, std::size_t batch,
T beta, TensorSpan<T> C, const std::vector<std::size_t> C_offsets, T alpha,
bool trans_a, TensorView<T> A, const std::vector<std::size_t> A_offsets,
bool trans_b, TensorView<T> B, const std::vector<std::size_t> B_offsets) {
const auto M = C.get_axis_size(-2),
N = C.get_axis_size(-1),
K = A.get_axis_size(trans_a ? -2 : -1);
const auto lda = A.get_axis_size(-1),
ldb = B.get_axis_size(-1),
ldc = N;
// collect pointers and run cublasSgemmBatched / cublasHgemmBatched
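// note: cuBLAS works on column-major matrices while dnn tensors are row-major; computing the
// row-major C = A * B as the column-major C^T = B^T * A^T lets the same buffers be passed
// unchanged, which is why B/A (and N/M) are swapped in the call below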
csl::cublas::gemmBatched<T>(handle, trans_b, trans_a, N, M, K, 1.f, B.get(), ldb, B_offsets, A.get(), lda, A_offsets, 0.f, C.get(), ldc, C_offsets, batch);
}
/** @brief performs element-wise addition with broadcasting
*
* Pre-conditions:

View File

@ -0,0 +1,79 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_BROADCAST_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_BROADCAST_HPP
#include "../../op_cuda.hpp"
#include "../csl/stream.hpp"
#include "../csl/cublas.hpp"
#include "../csl/tensor.hpp"
#include "../csl/tensor_ops.hpp"
#include <opencv2/core.hpp>
#include <utility>
namespace cv { namespace dnn { namespace cuda4dnn {
template <class T>
class MatMulBroadcastOp final : public CUDABackendNode {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
MatMulBroadcastOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat &B, bool _transA, bool _transB,
const std::vector<size_t> &A_offsets_, const std::vector<size_t> &B_offsets_, std::vector<size_t> &C_offsets_,
size_t batch_)
: stream(std::move(stream_)), cublasHandle(std::move(handle)), A_offsets(A_offsets_), B_offsets(B_offsets_), C_offsets(C_offsets_), batch(batch_)
{
if (!B.empty()) {
input_B_tensor = csl::makeTensorHeader<T>(B);
csl::copyMatToTensor<T>(B, input_B_tensor, stream);
}
transA = _transA;
transB = _transB;
}
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
{
CV_Assert(((inputs.size() == 2 && input_B_tensor.empty()) ||
(inputs.size() == 1 && !input_B_tensor.empty())) && outputs.size() == 1);
auto input_A_wrapper = inputs[0].dynamicCast<wrapper_type>();
auto input_A = input_A_wrapper->getView();
csl::TensorView<T> input_B;
if (input_B_tensor.empty()) {
auto input_B_wrapper = inputs[1].dynamicCast<wrapper_type>();
input_B = input_B_wrapper->getView();
} else {
input_B = csl::TensorView<T>(input_B_tensor);
}
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
csl::tensor_ops::gemmBatched<T>(cublasHandle, batch, 0.f, output, C_offsets, 1.f, transA, input_A, A_offsets, transB, input_B, B_offsets);
}
private:
csl::Stream stream;
csl::cublas::Handle cublasHandle;
csl::Tensor<T> input_B_tensor;
bool transA, transB;
std::vector<size_t> A_offsets;
std::vector<size_t> B_offsets;
std::vector<size_t> C_offsets;
size_t batch;
};
}}} /* namespace cv::dnn::cuda4dnn */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_BROADCAST_HPP */

View File

@ -102,6 +102,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(LRN, LRNLayer);
CV_DNN_REGISTER_LAYER_CLASS(InnerProduct, InnerProductLayer);
CV_DNN_REGISTER_LAYER_CLASS(Gemm, GemmLayer);
CV_DNN_REGISTER_LAYER_CLASS(MatMul, MatMulLayer);
CV_DNN_REGISTER_LAYER_CLASS(Softmax, SoftmaxLayer);
CV_DNN_REGISTER_LAYER_CLASS(SoftMax, SoftmaxLayer); // For compatibility. See https://github.com/opencv/opencv/issues/16877
CV_DNN_REGISTER_LAYER_CLASS(MVN, MVNLayer);

View File

@ -21,48 +21,76 @@
namespace cv { namespace dnn {
void fastGemmPackB(const Mat &B, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt) {
CV_CheckEQ(B.dims, 2, "fastGemmPackB: input mat should be two-dimensional");
CV_CheckTypeEQ(B.type(), CV_32F, "fastGemmPackB: only float32 is supported for now");
auto B_shape = shape(B);
int K = B_shape[0], N = B_shape[1], ldb0 = N, ldb1 = 1;
int batch = total(B_shape, 0, B_shape.size() - 2),
K = B_shape[B_shape.size() - 2], N = B_shape.back(), ldb0 = N, ldb1 = 1;
if (trans) {
std::swap(K, N);
std::swap(ldb0, ldb1);
}
const auto *b = B.ptr<const char>();
int esz = B.elemSize();
#if CV_TRY_NEON
if (opt.use_neon) {
int size_packed_B = opt_NEON::fastGemmPackBSize(N, K);
packed_B.resize(size_packed_B);
opt_NEON::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
packed_B.resize(size_packed_B * batch);
auto *packed_b = (char*)packed_B.data();
for (int i = 0; i < batch; i++) {
opt_NEON::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, esz);
b += N * K * esz;
packed_b += size_packed_B * esz;
}
} else
#endif
#if CV_TRY_AVX2
if (opt.use_avx2) {
int size_packed_B = opt_AVX2::fastGemmPackBSize(N, K);
packed_B.resize(size_packed_B);
opt_AVX2::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
packed_B.resize(size_packed_B * batch);
auto *packed_b = (char*)packed_B.data();
for (int i = 0; i < batch; i++) {
opt_AVX2::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, esz);
b += N * K * esz;
packed_b += size_packed_B * esz;
}
} else
#endif
#if CV_TRY_AVX
if (opt.use_avx) {
int size_packed_B = opt_AVX::fastGemmPackBSize(N, K);
packed_B.resize(size_packed_B);
opt_AVX::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
packed_B.resize(size_packed_B * batch);
auto *packed_b = (char*)packed_B.data();
for (int i = 0; i < batch; i++) {
opt_AVX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, esz);
b += N * K * esz;
packed_b += size_packed_B * esz;
}
} else
#endif
#if CV_TRY_LASX
if (opt.use_lasx) {
int size_packed_B = opt_LASX::fastGemmPackBSize(N, K);
packed_B.resize(size_packed_B);
opt_LASX::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
packed_B.resize(size_packed_B * batch);
auto *packed_b = (char*)packed_B.data();
for (int i = 0; i < batch; i++) {
opt_LASX::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, esz);
b += N * K * esz;
packed_b += size_packed_B * esz;
}
} else
#endif
{
int size_packed_B = cpu_baseline::fastGemmPackBSize(N, K);
packed_B.resize(size_packed_B);
cpu_baseline::fastGemmPackBKernel(B.ptr<const char>(), (char *)packed_B.data(), N, K, ldb0, ldb1, B.elemSize());
packed_B.resize(size_packed_B * batch);
auto *packed_b = (char*)packed_B.data();
for (int i = 0; i < batch; i++) {
cpu_baseline::fastGemmPackBKernel(b, packed_b, N, K, ldb0, ldb1, esz);
b += N * K * esz;
packed_b += size_packed_B * esz;
}
}
}
@ -131,7 +159,6 @@ void fastGemm(bool trans_a, int M, int N, int K,
void fastGemm(bool trans_a, bool trans_b, int ma, int na, int mb, int nb,
float alpha, const float *A, int lda0, int lda1, const float *B, int ldb0, int ldb1,
float beta, float *C, int ldc, FastGemmOpt &opt) {
const char *a = (const char *)A;
const char *b = (const char *)B;
char *c = (char *)C;
@ -209,54 +236,93 @@ void fastGemm(bool trans_a, bool trans_b,
beta, c, ldc, opt);
}
void fastGemmBatched(bool trans_a, bool trans_b,
float alpha, const Mat &A, const Mat &B,
float beta, Mat &C, FastGemmOpt &opt) {
CV_CheckTypeEQ(A.type(), B.type(), "DNN/fastGemmBatched: A and B should have the same type");
CV_CheckTypeEQ(B.type(), C.type(), "DNN/fastGemmBatched: B and C should have the same type");
CV_CheckTypeEQ(A.type(), CV_32F, "DNN/fastGemmBatched: only support float32 for now");
void fastGemmBatch(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const float *A, int lda0, int lda1,
const float *B, int ldb0, int ldb1, float beta, float *C, int ldc, FastGemmOpt &opt) {
const char *a = (const char *)A;
const char *b = (const char *)B;
char *c = (char *)C;
const auto shape_a = shape(A);
size_t dims_A = shape_a.size();
CV_CheckGE(dims_A, static_cast<size_t>(2), "DNN/fastGemmBatched: A must be n-dimensional (n >= 2)");
const auto shape_b = shape(B);
CV_CheckEQ(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatched: B must be 2-dimensional");
const auto shape_c = shape(C);
size_t dims_C = shape_c.size();
CV_CheckGE(dims_C, static_cast<size_t>(2), "DNN/fastGemmBatched: C must be n-dimensional (n >= 2)");
if (trans_a) {
int ma = shape_a[dims_A - 2], na = shape_a[dims_A - 1];
int mb = shape_b[0], nb = shape_b[1];
int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1];
const float *a = A.ptr<const float>();
const float *b = B.ptr<const float>();
float *c = C.ptr<float>();
int batches = std::accumulate(shape_a.begin(), shape_a.end() - 2, 1, std::multiplies<int>());
int step_a = ma * na, step_c = na * nb;
for (int i = 0; i < batches; i++) {
fastGemm(true, trans_b, ma, na, mb, nb,
alpha, a + i * step_a, lda0, lda1, b, ldb0, ldb1,
beta, c + i * step_c, ldc, opt);
}
} else {
int ma = std::accumulate(shape_a.begin(), shape_a.end() - 1, 1, std::multiplies<int>()),
na = shape_a[dims_A - 1];
int mb = shape_b[0], nb = shape_b[1];
int lda0 = na, lda1 = 1, ldb0 = nb, ldb1 = 1, ldc = shape_c[1];
const float *a = A.ptr<const float>();
const float *b = B.ptr<const float>();
float *c = C.ptr<float>();
fastGemm(false, trans_b, ma, na, mb, nb,
alpha, a, lda0, lda1, b, ldb0, ldb1,
beta, c, ldc, opt);
#if CV_TRY_NEON
if (opt.use_neon) {
opt_NEON::fastGemmBatchKernel(batch, A_offsets, B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, ldb0, ldb1, beta, c, ldc, sizeof(float));
} else
#endif
#if CV_TRY_AVX2
if (opt.use_avx2) {
opt_AVX2::fastGemmBatchKernel(batch, A_offsets, B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, ldb0, ldb1, beta, c, ldc, sizeof(float));
} else
#endif
#if CV_TRY_AVX
if (opt.use_avx) {
opt_AVX::fastGemmBatchKernel(batch, A_offsets, B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, ldb0, ldb1, beta, c, ldc, sizeof(float));
} else
#endif
#if CV_TRY_LASX
if (opt.use_lasx) {
opt_LASX::fastGemmBatchKernel(batch, A_offsets, B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, ldb0, ldb1, beta, c, ldc, sizeof(float));
} else
#endif
{
cpu_baseline::fastGemmBatchKernel(batch, A_offsets, B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, ldb0, ldb1, beta, c, ldc, sizeof(float));
}
}
void fastGemmBatch(size_t batch, const size_t *A_offsets, const size_t *packed_B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const float *A, int lda0, int lda1,
const float *packed_B, float beta, float *C, int ldc, FastGemmOpt &opt) {
const char *a = (const char *)A;
const char *b = (const char *)packed_B;
char *c = (char *)C;
#if CV_TRY_NEON
if (opt.use_neon) {
opt_NEON::fastGemmBatchKernel(batch, A_offsets, packed_B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, beta, c, ldc, sizeof(float));
} else
#endif
#if CV_TRY_AVX2
if (opt.use_avx2) {
opt_AVX2::fastGemmBatchKernel(batch, A_offsets, packed_B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, beta, c, ldc, sizeof(float));
} else
#endif
#if CV_TRY_AVX
if (opt.use_avx) {
opt_AVX::fastGemmBatchKernel(batch, A_offsets, packed_B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, beta, c, ldc, sizeof(float));
} else
#endif
#if CV_TRY_LASX
if (opt.use_lasx) {
opt_LASX::fastGemmBatchKernel(batch, A_offsets, packed_B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, beta, c, ldc, sizeof(float));
} else
#endif
{
cpu_baseline::fastGemmBatchKernel(batch, A_offsets, packed_B_offsets, C_offsets, M, N, K, alpha, a, lda0, lda1, b, beta, c, ldc, sizeof(float));
}
}
void fastGemmBatch(bool trans_a, bool trans_b,
float alpha, const Mat &A, const Mat &B,
float beta, Mat &C, FastGemmOpt &opt) {
CV_CheckTypeEQ(A.type(), B.type(), "DNN/fastGemmBatch: A and B should have the same type");
CV_CheckTypeEQ(B.type(), C.type(), "DNN/fastGemmBatch: B and C should have the same type");
CV_CheckTypeEQ(A.type(), CV_32F, "DNN/fastGemmBatch: only support float32 for now");
const auto shape_a = shape(A);
const auto shape_b = shape(B);
const auto shape_c = shape(C);
CV_CheckGE(shape_a.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: A must be n-dimensional (n >= 2)");
CV_CheckGE(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");
const float *a = A.ptr<const float>();
const float *b = B.ptr<const float>();
float *c = C.ptr<float>();
MatMulHelper helper;
helper.compute(trans_a, trans_b, shape_a, shape_b, shape_c);
fastGemmBatch(helper.batch, helper.A_offsets.data(), helper.B_offsets.data(), helper.C_offsets.data(),
helper.M, helper.N, helper.K, alpha, a, helper.lda0, helper.lda1, b, helper.ldb0,
helper.ldb1, beta, c, helper.ldc, opt);
}
}} // cv::dnn

View File

@ -42,6 +42,112 @@ struct FastGemmOpt {
}
};
struct MatMulHelper {
std::vector<size_t> A_offsets;
std::vector<size_t> B_offsets;
std::vector<size_t> packed_B_offsets;
std::vector<size_t> C_offsets;
std::vector<size_t> A_rows;
std::vector<size_t> B_rows;
std::vector<size_t> C_rows;
size_t batch;
int lda0, lda1;
int ldb0, ldb1;
int ldc;
int M, N, K;
MatMulHelper() {
A_offsets = {0};
B_offsets = {0};
packed_B_offsets = {0};
C_offsets = {0};
A_rows = {0};
B_rows = {0};
C_rows = {0};
batch = 0;
}
bool empty() const {
return batch == 0;
}
void compute(bool trans_a, bool trans_b, MatShape A_shape, MatShape B_shape, MatShape C_shape) {
auto A_ndims = A_shape.size(), B_ndims = B_shape.size(), C_ndims = C_shape.size();
int ma = A_shape[A_ndims - 2], na = A_shape.back();
int mb = B_shape[B_ndims - 2], nb = B_shape.back();
lda0 = na, lda1 = 1;
ldb0 = nb, ldb1 = 1;
ldc = C_shape.back();
M = trans_a ? na : ma;
N = trans_b ? mb : nb;
K = trans_a ? ma : na;
if (trans_a) {
std::swap(lda0, lda1);
}
if (trans_b) {
std::swap(ldb0, ldb1);
}
// compute offsets
auto batch_ndims = C_ndims - 2;
batch = total(C_shape, 0, batch_ndims);
A_offsets.resize(batch, 0);
B_offsets.resize(batch, 0);
C_offsets.resize(batch, 0);
A_rows.resize(batch, 0);
B_rows.resize(batch, 0);
C_rows.resize(batch, 0);
// build C_offsets
size_t C_step = total(C_shape, C_ndims - 2);
MatShape A_broadcast_shape(C_ndims, 1);
std::memcpy(A_broadcast_shape.data() + (C_ndims - A_ndims), A_shape.data(), A_ndims * sizeof(int));
MatShape B_broadcast_shape(C_shape.size(), 1);
std::memcpy(B_broadcast_shape.data() + (C_ndims - B_ndims), B_shape.data(), B_shape.size() * sizeof(int));
std::vector<size_t> A_steps(C_ndims, 1), B_steps(C_ndims, 1);
for (int i = C_ndims - 2; i >= 0; i--) {
A_steps[i] = A_steps[i + 1] * A_broadcast_shape[i + 1];
B_steps[i] = B_steps[i + 1] * B_broadcast_shape[i + 1];
}
size_t t, idx;
for (size_t i = 0; i < batch; i++) {
C_offsets[i] = i * C_step;
C_rows[i] = i;
size_t A_offset = 0, B_offset = 0;
t = i;
for (int j = batch_ndims - 1; j >= 0; j--) {
idx = t / C_shape[j];
int idx_offset = (int)(t - idx * C_shape[j]);
A_offset += A_broadcast_shape[j] == 1 ? 0 : idx_offset * A_steps[j];
B_offset += B_broadcast_shape[j] == 1 ? 0 : idx_offset * B_steps[j];
t = idx;
}
A_offsets[i] = A_offset;
B_offsets[i] = B_offset;
A_rows[i] = A_offset / (M * K);
B_rows[i] = B_offset / (N * K);
}
}
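// Worked example (hypothetical shapes): A = [2, 1, 3, 4], B = [1, 5, 4, 6], trans_a = trans_b = false
// broadcast to C = [2, 5, 3, 6], so batch = 10, M = 3, N = 6, K = 4. Walking batch index i in
// row-major order over [2, 5], an axis contributes to an offset only where its broadcast dim is not 1:
//   A_offsets = {0, 0, 0, 0, 0, 12, 12, 12, 12, 12}     (each A slice holds 3*4 floats)
//   B_offsets = {0, 24, 48, 72, 96, 0, 24, 48, 72, 96}  (each B slice holds 4*6 floats)
//   C_offsets = {0, 18, 36, ..., 162}                   (each C slice holds 3*6 floats)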
// only run after compute
void updatePackedBOffsets(size_t packed_B_size) {
size_t packed_B_inner_size = packed_B_size / batch;
packed_B_offsets.resize(B_offsets.size());
for (size_t i = 0; i < packed_B_offsets.size(); i++) {
packed_B_offsets[i] = (B_offsets[i] / (N * K)) * packed_B_inner_size;
}
}
};
void fastGemmPackB(const Mat &m, std::vector<float> &packed_B, bool trans, FastGemmOpt &opt);
void fastGemm(bool trans_a, int M, int N, int K,
@ -55,10 +161,14 @@ void fastGemm(bool trans_a, bool trans_b,
float alpha, const Mat &A, const Mat &B,
float beta, Mat &C, FastGemmOpt &opt);
// FIXME: B needs to 2d for now. Support nd (n>=2) B in the future.
void fastGemmBatched(bool trans_a, bool trans_b,
float alpha, const Mat &A, const Mat &B,
float beta, Mat &C, FastGemmOpt &opt);
void fastGemmBatch(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const float *A, int lda0, int lda1,
const float *B, int ldb0, int ldb1, float beta, float *C, int ldc, FastGemmOpt &opt);
void fastGemmBatch(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const float *A, int lda0, int lda1,
const float *packed_B, float beta, float *C, int ldc, FastGemmOpt &opt);
void fastGemmBatch(bool trans_a, bool trans_b, float alpha, const Mat &A,
const Mat &B, float beta, Mat &C, FastGemmOpt &opt);
}} // cv::dnn

View File

@ -88,6 +88,13 @@ void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz);
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1, float beta, char *C, int ldc, int esz);
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz);
FAST_GEMM_IMPLEMENT_PACK(8, _f32, float, float)
FAST_GEMM_IMPLEMENT_PACK(12, _f32, float, float)
@ -300,6 +307,153 @@ void fastGemmKernel(int M, int N, int K,
parallel_for_(Range(0, total), fn, nstripes);
}
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1, float beta, char *C, int ldc, int esz) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
GEMM_NR = FAST_GEMM_F32_NR;
int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K);
size_t buff_size = KC * (MC + NC) * esz;
bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF;
int m_tiles = (M + MC - 1) / MC;
int n_tiles = (N + NC - 1) / NC;
int total_tiles = m_tiles * n_tiles;
auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
char* packed_b = packed_a + KC * MC * esz;
int start = r.start;
int end = r.end;
for (int tile_idx = start; tile_idx < end; tile_idx++) {
const int batch_index = static_cast<int>(tile_idx / total_tiles);
const int m_tiles_index = static_cast<int>((tile_idx - batch_index * total_tiles) / n_tiles);
const int n_tiles_index = static_cast<int>(tile_idx % n_tiles);
int i0 = m_tiles_index * MC;
int j0 = n_tiles_index * NC;
int mc = M - i0 < MC ? M - i0 : MC;
int nc = N - j0 < NC ? N - j0 : NC;
int ldc_block = ldc;
const char *a_block = A + A_offsets[batch_index] * esz;
const char *b_block = B + B_offsets[batch_index] * esz;
char* c_block = C + C_offsets[batch_index] * esz + (i0 * ldc + j0) * esz;
if (beta == 0.f) {
for(int i = 0; i < mc; i++)
memset(c_block + i * ldc_block * esz, 0, nc * esz);
} else if (beta != 1.f) {
for(int i = 0; i < mc; i++) {
float* c_i = (float*)c_block + i * ldc_block;
for(int j = 0; j < nc; j++)
c_i[j] *= beta;
}
}
for(int k0 = 0; k0 < K; k0 += KC)
{
int kc = K - k0 < KC ? K - k0 : KC;
// pack a
fast_gemm_pack8_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
// pack b
fast_gemm_pack12_f32(nc, kc, b_block + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b);
// run kernel
fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz);
}
}
if (!use_stackbuff) {
free(packed_a);
}
};
int total = batch * total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
}
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
GEMM_NR = FAST_GEMM_F32_NR;
int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K);
size_t buff_size = KC * MC * esz;
bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF;
int m_tiles = (M + MC - 1) / MC;
int n_tiles = (N + NC - 1) / NC;
int total_tiles = m_tiles * n_tiles;
auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
const char *packed_b = packed_B;
int start = r.start;
int end = r.end;
for (int tile_idx = start; tile_idx < end; tile_idx++) {
const int batch_index = static_cast<int>(tile_idx / total_tiles);
const int m_tiles_index = static_cast<int>((tile_idx - batch_index * total_tiles) / n_tiles);
const int n_tiles_index = static_cast<int>(tile_idx % n_tiles);
int i0 = m_tiles_index * MC;
int j0 = n_tiles_index * NC;
int mc = M - i0 < MC ? M - i0 : MC;
int nc = N - j0 < NC ? N - j0 : NC;
int ldc_block = ldc;
const char *a_block = A + A_offsets[batch_index] * esz;
packed_b = packed_B + B_offsets[batch_index] * esz + j0 * K * esz;
char* c_block = C + C_offsets[batch_index] * esz + (i0 * ldc + j0) * esz;
if (beta == 0.f) {
for(int i = 0; i < mc; i++)
memset(c_block + i * ldc_block * esz, 0, nc * esz);
} else if (beta != 1.f) {
for(int i = 0; i < mc; i++) {
float* c_i = (float*)c_block + i * ldc_block;
for(int j = 0; j < nc; j++)
c_i[j] *= beta;
}
}
int _nc = static_cast<int>((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz;
for(int k0 = 0; k0 < K; k0 += KC)
{
int kc = K - k0 < KC ? K - k0 : KC;
// pack a
fast_gemm_pack8_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
// run kernel
fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz);
packed_b += _nc * kc;
}
}
if (!use_stackbuff) {
free(packed_a);
}
};
int total = batch * total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
}
}}} // cv::dnn::cpu_baseline
#undef FAST_GEMM_STORAGE

View File

@ -22,8 +22,8 @@
#define FAST_GEMM_F32_MC 48
#define FAST_GEMM_F32_NC 128
#else // CV_NEON_AARCH64, SIMD128
#define FAST_GEMM_F32_MC 64
#define FAST_GEMM_F32_NC 240
#define FAST_GEMM_F32_MC 144
#define FAST_GEMM_F32_NC 72
#endif
#if CV_AVX
@ -127,6 +127,13 @@ void fastGemmKernel(int M, int N, int K,
float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz);
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1, float beta, char *C, int ldc, int esz);
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz);
#ifndef CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY
/*
@ -721,6 +728,177 @@ void fastGemmKernel(int M, int N, int K,
parallel_for_(Range(0, total), fn, nstripes);
}
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *B, int ldb0, int ldb1, float beta, char *C, int ldc, int esz) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
GEMM_NR = FAST_GEMM_F32_NR;
int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K);
size_t buff_size = KC * (MC + NC) * esz;
bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF;
int m_tiles = (M + MC - 1) / MC;
int n_tiles = (N + NC - 1) / NC;
int total_tiles = m_tiles * n_tiles;
auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
char* packed_b = packed_a + KC * MC * esz;
int start = r.start;
int end = r.end;
for (int tile_idx = start; tile_idx < end; tile_idx++) {
const int batch_index = static_cast<int>(tile_idx / total_tiles);
const int m_tiles_index = static_cast<int>((tile_idx - batch_index * total_tiles) / n_tiles);
const int n_tiles_index = static_cast<int>(tile_idx % n_tiles);
int i0 = m_tiles_index * MC;
int j0 = n_tiles_index * NC;
int mc = M - i0 < MC ? M - i0 : MC;
int nc = N - j0 < NC ? N - j0 : NC;
int ldc_block = ldc;
const char *a_block = A + A_offsets[batch_index] * esz;
const char *b_block = B + B_offsets[batch_index] * esz;
char* c_block = C + C_offsets[batch_index] * esz + (i0 * ldc + j0) * esz;
if (beta == 0.f) {
for(int i = 0; i < mc; i++)
memset(c_block + i * ldc_block * esz, 0, nc * esz);
} else if (beta != 1.f) {
for(int i = 0; i < mc; i++) {
float* c_i = (float*)c_block + i * ldc_block;
for(int j = 0; j < nc; j++)
c_i[j] *= beta;
}
}
for(int k0 = 0; k0 < K; k0 += KC)
{
int kc = K - k0 < KC ? K - k0 : KC;
// pack a
#if CV_NEON && CV_NEON_AARCH64
fast_gemm_pack8_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#elif CV_AVX
fast_gemm_pack12_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#elif CV_LASX
fast_gemm_pack12_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#elif CV_SIMD128
fast_gemm_pack8_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#endif
// pack b
#if CV_NEON && CV_NEON_AARCH64
fast_gemm_pack12_f32(nc, kc, b_block + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b);
#elif CV_AVX
fast_gemm_pack8_f32(nc, kc, b_block + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b);
#elif CV_LASX
fast_gemm_pack16_f32(nc, kc, b_block + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b);
#elif CV_SIMD128
fast_gemm_pack12_f32(nc, kc, b_block + (k0 * ldb0 + j0 * ldb1) * esz, ldb1, ldb0, packed_b);
#endif
// run kernel
fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz);
}
}
if (!use_stackbuff) {
free(packed_a);
}
};
int total = batch * total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
}
void fastGemmBatchKernel(size_t batch, const size_t *A_offsets, const size_t *B_offsets, const size_t *C_offsets,
int M, int N, int K, float alpha, const char *A, int lda0, int lda1,
const char *packed_B, float beta, char *C, int ldc, int esz) {
int GEMM_MC = FAST_GEMM_F32_MC,
GEMM_NC = FAST_GEMM_F32_NC,
GEMM_MR = FAST_GEMM_F32_MR,
GEMM_NR = FAST_GEMM_F32_NR;
int MC = (((GEMM_MC < M ? GEMM_MC : M) + GEMM_MR - 1) / GEMM_MR) * GEMM_MR;
int NC = (((GEMM_NC < N ? GEMM_NC : N) + GEMM_NR - 1) / GEMM_NR) * GEMM_NR;
int KC = std::min(FAST_GEMM_F32_PACKED_STRIDE_K, K);
size_t buff_size = KC * MC * esz;
bool use_stackbuff = buff_size <= FAST_GEMM_MAX_STACKBUF;
int m_tiles = (M + MC - 1) / MC;
int n_tiles = (N + NC - 1) / NC;
int total_tiles = m_tiles * n_tiles;
auto fn = [&](const Range &r) {
char* packed_a = (char*)(use_stackbuff ? alloca(buff_size) : malloc(buff_size));
const char *packed_b = packed_B;
int start = r.start;
int end = r.end;
for (int tile_idx = start; tile_idx < end; tile_idx++) {
const int batch_index = static_cast<int>(tile_idx / total_tiles);
const int m_tiles_index = static_cast<int>((tile_idx - batch_index * total_tiles) / n_tiles);
const int n_tiles_index = static_cast<int>(tile_idx % n_tiles);
int i0 = m_tiles_index * MC;
int j0 = n_tiles_index * NC;
int mc = M - i0 < MC ? M - i0 : MC;
int nc = N - j0 < NC ? N - j0 : NC;
int ldc_block = ldc;
const char *a_block = A + A_offsets[batch_index] * esz;
packed_b = packed_B + B_offsets[batch_index] * esz + j0 * K * esz;
char* c_block = C + C_offsets[batch_index] * esz + (i0 * ldc + j0) * esz;
if (beta == 0.f) {
for(int i = 0; i < mc; i++)
memset(c_block + i * ldc_block * esz, 0, nc * esz);
} else if (beta != 1.f) {
for(int i = 0; i < mc; i++) {
float* c_i = (float*)c_block + i * ldc_block;
for(int j = 0; j < nc; j++)
c_i[j] *= beta;
}
}
int _nc = static_cast<int>((nc + GEMM_NR - 1) / GEMM_NR) * GEMM_NR * esz;
for(int k0 = 0; k0 < K; k0 += KC)
{
int kc = K - k0 < KC ? K - k0 : KC;
// pack a
#if CV_NEON && CV_NEON_AARCH64
fast_gemm_pack8_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#elif CV_AVX
fast_gemm_pack12_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#elif CV_LASX
fast_gemm_pack12_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#elif CV_SIMD128
fast_gemm_pack8_f32(mc, kc, a_block + (i0 * lda0 + k0 * lda1) * esz, lda0, lda1, packed_a);
#endif
// run kernel
fast_gemm_macro_kernel(mc, nc, kc, packed_a, packed_b, alpha, c_block, ldc_block, esz);
packed_b += _nc * kc;
}
}
if (!use_stackbuff) {
free(packed_a);
}
};
int total = batch * total_tiles;
int cost_per_thread = static_cast<int>((K / KC) * (MC / GEMM_MR) * (NC / GEMM_NR));
double nstripes = (size_t)total * cost_per_thread * (1 / 1024.0);
parallel_for_(Range(0, total), fn, nstripes);
}
#endif // CV_CPU_OPTIMIZATION_DECLARATIONS_ONLY
CV_CPU_OPTIMIZATION_NAMESPACE_END

View File

@ -211,7 +211,7 @@ public:
CV_CheckGT(packed_B.size(), static_cast<size_t>(0), "DNN/Gemm: constant B is not pre-packed");
fastGemm(trans_a, M, N, K, alpha, A.ptr<const float>(), na, packed_B.data(), 1.f, Y.ptr<float>(), N, opt);
} else {
fastGemmBatched(trans_a, trans_b, alpha, A, inputs[1], 1.f, Y, opt);
fastGemmBatch(trans_a, trans_b, alpha, A, inputs[1], 1.f, Y, opt);
}
}

View File

@ -0,0 +1,326 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#include <opencv2/dnn/shape_utils.hpp>
#include "cpu_kernels/fast_gemm.hpp"
// OpenVINO backend
#include "../op_inf_engine.hpp"
#include "../ie_ngraph.hpp"
// Vulkan backend
#include "../op_vkcom.hpp"
// CUDA backend
#ifdef HAVE_CUDA
#include "../cuda4dnn/primitives/matmul_broadcast.hpp"
using namespace cv::dnn::cuda4dnn;
#endif
// CANN backend
#include "../op_cann.hpp"
namespace cv { namespace dnn {
class MatMulLayerImpl CV_FINAL : public MatMulLayer {
public:
MatMulLayerImpl(const LayerParams& params) {
setParamsFrom(params);
trans_a = params.get<bool>("transA", false);
trans_b = params.get<bool>("transB", false);
alpha = params.get<float>("alpha", 1.f);
beta = params.get<float>("beta", 1.f);
}
virtual bool supportBackend(int backendId) CV_OVERRIDE {
return backendId == DNN_BACKEND_OPENCV ||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH ||
(backendId == DNN_BACKEND_VKCOM && haveVulkan() && !trans_a && !trans_b) ||
backendId == DNN_BACKEND_CUDA ||
backendId == DNN_BACKEND_CANN;
}
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
std::vector<MatShape> &internals) const CV_OVERRIDE {
CV_CheckGE(inputs.size(), static_cast<size_t>(1), "DNN/MatMul: one variable input at least");
CV_CheckLE(inputs.size(), static_cast<size_t>(2), "DNN/MatMul: two variable inputs at most");
const auto shape_A = inputs[0], shape_B = blobs.empty() ? inputs[1] : shape(blobs[0]);
CV_CheckGE(shape_A.size(), static_cast<size_t>(2), "DNN/MatMul: invalid shape of input A");
CV_CheckGE(shape_B.size(), static_cast<size_t>(2), "DNN/MatMul: invalid shape of input B");
// Check legal matrix multiplication
int mA = shape_A[shape_A.size() - 2], nA = shape_A.back();
int mB = shape_B[shape_B.size() - 2], nB = shape_B.back();
int M = trans_a ? nA : mA;
int N = trans_b ? mB : nB;
int K_A = trans_a ? mA : nA;
int K_B = trans_b ? nB : mB;
CV_CheckEQ(K_A, K_B, "DNN/MatMul: invalid dimension K");
// Check legal broadcast. It is legal for sure if A and B are 2d, or one of them is 2d.
MatShape common_shape;
if (shape_A.size() != 2 || shape_B.size() != 2) {
const auto &shape_more_dims = shape_A.size() > shape_B.size() ? shape_A : shape_B;
const auto &shape_less_dims = shape_A.size() > shape_B.size() ? shape_B : shape_A;
size_t diff_dims = shape_more_dims.size() - shape_less_dims.size();
common_shape = shape_more_dims;
for (size_t i = 0; i < shape_less_dims.size() - 2; i++) {
const auto dl = shape_less_dims[i], dm = shape_more_dims[i + diff_dims];
if (dl != 1 && dm != 1 && dl != dm) {
CV_Error(Error::StsBadSize, format("DNN/MatMul: invalid shape for broadcasting, shape_A[%zu]=%d, shape_B[%zu]=%d\n", i, shape_less_dims[i], i, shape_more_dims[i + diff_dims]));
}
if (dm == 1) {
common_shape[i + diff_dims] = dl;
}
}
common_shape[common_shape.size() - 2] = M;
common_shape[common_shape.size() - 1] = N;
} else {
common_shape.resize(2);
common_shape[0] = M;
common_shape[1] = N;
}
outputs.assign(1, common_shape);
return false;
}
virtual void finalize(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr) CV_OVERRIDE {
opt.init();
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
const auto A_shape = shape(inputs[0]),
B_shape = blobs.empty() ? shape(inputs[1]) : shape(blobs[0]),
C_shape = shape(outputs[0]);
helper.compute(trans_a, trans_b, A_shape, B_shape, C_shape);
if (!blobs.empty()) {
fastGemmPackB(blobs[0], packed_input_B, trans_b, opt);
helper.updatePackedBOffsets(packed_input_B.size());
}
}
// works like Y = numpy.matmul(A, B)
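// e.g. A: [2, 3, 4, 5] with B: [5, 6] gives Y: [2, 3, 4, 6]; batch dimensions broadcast as in numpy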
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE {
CV_TRACE_FUNCTION();
CV_TRACE_ARG_VALUE(name, "name", name.c_str());
CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
forward_ocl(inputs_arr, outputs_arr, internals_arr))
if (inputs_arr.depth() == CV_16S)
{
forward_fallback(inputs_arr, outputs_arr, internals_arr);
return;
}
std::vector<Mat> inputs, outputs;
inputs_arr.getMatVector(inputs);
outputs_arr.getMatVector(outputs);
const auto &A = inputs[0];
auto &Y = outputs[0];
const auto *a = A.ptr<const float>();
auto *y = Y.ptr<float>();
std::memset(y, 0, Y.total() * sizeof(float));
if (blobs.empty()) {
const auto &B = inputs[1];
const auto *b = B.ptr<const float>();
fastGemmBatch(helper.batch, helper.A_offsets.data(), helper.B_offsets.data(), helper.C_offsets.data(),
helper.M, helper.N, helper.K, alpha, a, helper.lda0, helper.lda1,
b, helper.ldb0, helper.ldb1, beta, y, helper.ldc, opt);
} else {
fastGemmBatch(helper.batch, helper.A_offsets.data(), helper.packed_B_offsets.data(), helper.C_offsets.data(),
helper.M, helper.N, helper.K, alpha, a, helper.lda0, helper.lda1,
packed_input_B.data(), beta, y, helper.ldc, opt);
}
}
#ifdef HAVE_OPENCL
bool forward_ocl(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, InputArrayOfArrays internals) {
std::vector<UMat> inputs;
std::vector<UMat> outputs;
bool use_half = (inputs_arr.depth() == CV_16S);
inputs_arr.getUMatVector(inputs);
outputs_arr.getUMatVector(outputs);
const auto &input_A = inputs[0];
UMat input_B;
if (blobs.empty()) {
input_B = inputs[1];
} else {
blobs[0].copyTo(input_B);
}
auto &output = outputs[0];
int M = static_cast<int>(helper.M),
N = static_cast<int>(helper.N),
K = static_cast<int>(helper.K),
batch = static_cast<int>(helper.batch);
int batch_A = total(shape(input_A)) / (M * K),
batch_B = total(shape(input_B)) / (N * K);
MatShape new_shape_A{batch_A, M * K}, new_shape_B{batch_B, N * K}, new_shape_output{batch, M * N};
const auto input_A_2d = input_A.reshape(1, new_shape_A.size(), &new_shape_A[0]),
input_B_2d = input_B.reshape(1, new_shape_B.size(), &new_shape_B[0]);
auto output_2d = output.reshape(1, new_shape_output.size(), &new_shape_output[0]);
UMat A, B, C, A_fp32, B_fp32, C_fp32;
for (int i = 0; i < batch; i++) {
A = input_A_2d.row(helper.A_rows[i]).reshape(1, trans_a ? K : M);
B = input_B_2d.row(helper.B_rows[i]).reshape(1, trans_b ? K : N);
C = output_2d.row(helper.C_rows[i]).reshape(1, M);
if (trans_a) {
A = A.t();
}
if (trans_b) {
B = B.t();
}
if (use_half) {
convertFp16(A, A_fp32);
convertFp16(B, B_fp32);
convertFp16(C, C_fp32);
} else {
A_fp32 = A;
B_fp32 = B;
C_fp32 = C;
}
cv::gemm(A_fp32, B_fp32, 1.f, noArray(), 0.f, C_fp32);
if (use_half) {
convertFp16(A_fp32, A);
convertFp16(B_fp32, B);
convertFp16(C_fp32, C);
}
}
return true;
}
#endif // HAVE_OPENCL
#ifdef HAVE_DNN_NGRAPH
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE {
auto& input_A_node = nodes[0].dynamicCast<InfEngineNgraphNode>()->node;
std::shared_ptr<ngraph::Node> matmul;
if (nodes.size() == 2) {
auto &input_B_node = nodes[1].dynamicCast<InfEngineNgraphNode>()->node;
matmul = std::make_shared<ngraph::op::MatMul>(input_A_node, input_B_node, trans_a, trans_b);
} else {
auto input_B_shape = getShape<size_t>(blobs[0]);
auto input_B_node = std::make_shared<ngraph::op::Constant>(ngraph::element::f32, input_B_shape, blobs[0].data);
matmul = std::make_shared<ngraph::op::MatMul>(input_A_node, input_B_node, trans_a, trans_b);
}
return Ptr<BackendNode>(new InfEngineNgraphNode(matmul));
}
#endif // HAVE_DNN_NGRAPH
#ifdef HAVE_VULKAN
virtual Ptr<BackendNode> initVkCom(const std::vector<Ptr<BackendWrapper> > &inputs,
std::vector<Ptr<BackendWrapper> > &outputs) CV_OVERRIDE {
auto input_A_wrapper = inputs[0].dynamicCast<VkComBackendWrapper>();
auto output_wrapper = outputs[0].dynamicCast<VkComBackendWrapper>();
const auto input_A_shape = shape(*input_A_wrapper->getMat());
const auto output_shape = shape(*output_wrapper->getMat());
if (output_shape.size() != 2) {
return Ptr<BackendNode>();
}
std::vector<Mat> constants;
if (!blobs.empty()) {
constants.push_back(blobs[0]);
}
Ptr<vkcom::OpBase> op = new vkcom::OpMatMul(constants, input_A_shape[0], input_A_shape[1], output_shape[1]);
return Ptr<BackendNode>(new VkComBackendNode(inputs, op, outputs));
}
#endif
#ifdef HAVE_CUDA
Ptr<BackendNode> initCUDA(void *context_,
const std::vector<Ptr<BackendWrapper>>& inputs,
const std::vector<Ptr<BackendWrapper>>& outputs) override {
auto context = reinterpret_cast<csl::CSLContext*>(context_);
auto input_B = blobs.empty() ? Mat() : blobs[0];
CV_CheckFalse(helper.empty(), "DNN/MatMul/CUDA: MatMulHelper is not initialized");
return make_cuda_node<cuda4dnn::MatMulBroadcastOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), input_B, trans_a, trans_b, helper.A_offsets, helper.B_offsets, helper.C_offsets, helper.batch);
}
#endif // HAVE_CUDA
#ifdef HAVE_CANN
virtual Ptr<BackendNode> initCann(const std::vector<Ptr<BackendWrapper> > &inputs,
const std::vector<Ptr<BackendWrapper> > &outputs,
const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE {
auto input_A_wrapper = inputs[0].dynamicCast<CannBackendWrapper>();
auto input_A_desc = input_A_wrapper->getTensorDesc();
auto input_A_node = nodes[0].dynamicCast<CannBackendNode>()->getOp();
auto op = std::make_shared<ge::op::BatchMatMul>(name);
// set attributes
op->set_attr_adj_x1(trans_a);
op->set_attr_adj_x2(trans_b);
// set inputs
// set inputs : x1
op->set_input_x1_by_name(*input_A_node, input_A_wrapper->name.c_str());
op->update_input_desc_x1(*input_A_desc);
// set inputs : x2
if (blobs.empty()) { // variable input B
auto input_B_wrapper = inputs[1].dynamicCast<CannBackendWrapper>();
auto input_B_desc = input_B_wrapper->getTensorDesc();
auto input_B_node = nodes[1].dynamicCast<CannBackendNode>()->getOp();
op->set_input_x2_by_name(*input_B_node, "y");
op->update_input_desc_x2(*input_B_desc);
} else { // constant input B
auto B = blobs[0];
auto const_B_node = std::make_shared<CannConstOp>(B.data, B.type(), shape(B), cv::format("%s_B", name.c_str()));
op->set_input_x2_by_name(*(const_B_node->getOp()), "y");
op->update_input_desc_x2(*(const_B_node->getTensorDesc()));
}
// set outputs
auto output_desc = std::make_shared<ge::TensorDesc>(ge::Shape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
op->update_output_desc_y(*output_desc);
return Ptr<BackendNode>(new CannBackendNode(op));
}
#endif // HAVE_CANN
private:
bool trans_a;
bool trans_b;
float alpha;
float beta;
std::vector<float> packed_input_B;
FastGemmOpt opt;
MatMulHelper helper;
};
Ptr<MatMulLayer> MatMulLayer::create(const LayerParams& params)
{
return makePtr<MatMulLayerImpl>(params);
}
}} // cv::dnn

View File

@ -1957,50 +1957,33 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
addLayer(layerParams, node_proto);
}
void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{
opencv_onnx::NodeProto node_proto = node_proto_;
CV_Assert(node_proto.input_size() == 2);
layerParams.type = "InnerProduct";
layerParams.set("bias_term", false);
int firstInpDims, secondInpDims;
void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_) {
auto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 2, "ONNXImporter/MatMul: two inputs required");
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{
Mat blob = getBlob(node_proto, 0);
firstInpDims = blob.dims;
LayerParams constParams;
constParams.name = layerParams.name + "/const_0";
constParams.type = "Const";
constParams.blobs.push_back(blob);
for (int i = 0; i < node_proto.input_size(); i++) {
if (constBlobs.find(node_proto.input(i)) == constBlobs.end()) {
continue;
}
opencv_onnx::NodeProto tmpProto;
tmpProto.add_output(constParams.name);
addLayer(constParams, tmpProto);
Mat blob = getBlob(node_proto, i);
node_proto.set_input(0, constParams.name);
if (i == 1) {
layerParams.blobs.push_back(blob);
} else {
LayerParams const_params;
const_params.name = node_proto.input(i);
const_params.type = "Const";
const_params.blobs.push_back(blob);
opencv_onnx::NodeProto const_node_proto;
const_node_proto.add_output(const_params.name);
addLayer(const_params, const_node_proto);
node_proto.set_input(i, const_params.name);
}
}
else
firstInpDims = outShapes[node_proto.input(0)].size();
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
{
Mat blob = getBlob(node_proto, 1);
Mat transBlob;
secondInpDims = blob.dims;
// create order transposing last 2 dimensions
std::vector<int> order(secondInpDims);
std::iota(order.begin(), order.end(), 0);
std::swap(order[secondInpDims - 2], order[secondInpDims - 1]);
transposeND(blob, order, transBlob);
layerParams.blobs.push_back(transBlob);
int numOutput = layerParams.blobs[0].total(0, secondInpDims - 1);
layerParams.set("num_output", numOutput);
layerParams.set("is_matmul", secondInpDims > 2);
} else
secondInpDims = outShapes[node_proto.input(1)].size();
layerParams.set("axis", firstInpDims - 1);
addLayer(layerParams, node_proto);
}