From 32df5faa25411b9317801c17b6c2672b6c659d3b Mon Sep 17 00:00:00 2001
From: YashasSamaga
Date: Sat, 22 May 2021 01:01:29 +0530
Subject: [PATCH] add MatMulOp

---
 modules/dnn/src/cuda4dnn/csl/cublas.hpp          | 116 ++++++++++++++++++
 modules/dnn/src/cuda4dnn/csl/tensor.hpp          |  60 +++++++++
 modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp      |  94 ++++++++++----
 .../dnn/src/cuda4dnn/primitives/matmul.hpp       |  95 ++++++++++++++
 modules/dnn/src/layers/fully_connected_layer.cpp |   9 +-
 modules/dnn/test/test_onnx_importer.cpp          |   4 -
 6 files changed, 349 insertions(+), 29 deletions(-)
 create mode 100644 modules/dnn/src/cuda4dnn/primitives/matmul.hpp

diff --git a/modules/dnn/src/cuda4dnn/csl/cublas.hpp b/modules/dnn/src/cuda4dnn/csl/cublas.hpp
index 2928cda779..760e3824fd 100644
--- a/modules/dnn/src/cuda4dnn/csl/cublas.hpp
+++ b/modules/dnn/src/cuda4dnn/csl/cublas.hpp
@@ -247,6 +247,122 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cublas {
         );
     }
 
+    /** @brief Strided batched GEMM for column-major matrices
+     *
+     * \f$ C_i = \alpha A_i B_i + \beta C_i \f$ for a stack of matrices A, B and C indexed by i
+     *
+     * @tparam T matrix element type (must be `half` or `float`)
+     *
+     * @param handle valid cuBLAS Handle
+     * @param transa use transposed matrix of A_i for computation
+     * @param transb use transposed matrix of B_i for computation
+     * @param rows_c number of rows in C_i
+     * @param cols_c number of columns in C_i
+     * @param common_dim common dimension of A_i (or trans A_i) and B_i (or trans B_i)
+     * @param alpha scale factor for A_i B_i
+     * @param[in] A pointer to stack of column-major matrices A in device memory
+     * @param lda leading dimension of matrix A_i
+     * @param strideA stride between matrices in A
+     * @param[in] B pointer to stack of column-major matrices B in device memory
+     * @param ldb leading dimension of matrix B_i
+     * @param strideB stride between matrices in B
+     * @param beta scale factor for C_i
+     * @param[in,out] C pointer to stack of column-major matrices C in device memory
+     * @param ldc leading dimension of matrix C_i
+     * @param strideC stride between matrices in C
+     * @param batchCount number of matrices in the batch
+     *
+     * Exception Guarantee: Basic
+     */
+    template <class T>
+    void gemmStridedBatched(const Handle& handle,
+        bool transa, bool transb,
+        std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
+        T alpha, const DevicePtr<const T> A, std::size_t lda, std::size_t strideA,
+        const DevicePtr<const T> B, std::size_t ldb, std::size_t strideB,
+        T beta, const DevicePtr<T> C, std::size_t ldc, std::size_t strideC,
+        std::size_t batchCount);
+
+    template <> inline
+    void gemmStridedBatched<half>(const Handle& handle,
+        bool transa, bool transb,
+        std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
+        half alpha, const DevicePtr<const half> A, std::size_t lda, std::size_t strideA,
+        const DevicePtr<const half> B, std::size_t ldb, std::size_t strideB,
+        half beta, const DevicePtr<half> C, std::size_t ldc, std::size_t strideC,
+        std::size_t batchCount)
+    {
+        CV_Assert(handle);
+
+        const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
+                   opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
+        const auto irows_c = static_cast<int>(rows_c),
+                   icols_c = static_cast<int>(cols_c),
+                   icommon_dim = static_cast<int>(common_dim),
+                   ilda = static_cast<int>(lda),
+                   ildb = static_cast<int>(ldb),
+                   ildc = static_cast<int>(ldc);
+
+        const auto batch_count = static_cast<int>(batchCount);
+        const auto stride_a = static_cast<long long int>(strideA),
+                   stride_b = static_cast<long long int>(strideB),
+                   stride_c = static_cast<long long int>(strideC);
+
+        CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
+
+        CUDA4DNN_CHECK_CUBLAS(
+            cublasHgemmStridedBatched(
+                handle.get(),
+                opa, opb,
+                irows_c, icols_c, icommon_dim,
+                &alpha, A.get(), ilda, stride_a,
+                B.get(), ildb, stride_b,
+                &beta, C.get(), ildc, stride_c,
+                batch_count
+            )
+        );
+    }
+
+    template <> inline
+    void gemmStridedBatched<float>(const Handle& handle,
+        bool transa, bool transb,
+        std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
+        float alpha, const DevicePtr<const float> A, std::size_t lda, std::size_t strideA,
+        const DevicePtr<const float> B, std::size_t ldb, std::size_t strideB,
+        float beta, const DevicePtr<float> C, std::size_t ldc, std::size_t strideC,
+        std::size_t batchCount)
+    {
+        CV_Assert(handle);
+
+        const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
+                   opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
+        const auto irows_c = static_cast<int>(rows_c),
+                   icols_c = static_cast<int>(cols_c),
+                   icommon_dim = static_cast<int>(common_dim),
+                   ilda = static_cast<int>(lda),
+                   ildb = static_cast<int>(ldb),
+                   ildc = static_cast<int>(ldc);
+
+        const auto batch_count = static_cast<int>(batchCount);
+        const auto stride_a = static_cast<long long int>(strideA),
+                   stride_b = static_cast<long long int>(strideB),
+                   stride_c = static_cast<long long int>(strideC);
+
+        CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
+
+        CUDA4DNN_CHECK_CUBLAS(
+            cublasSgemmStridedBatched(
+                handle.get(),
+                opa, opb,
+                irows_c, icols_c, icommon_dim,
+                &alpha, A.get(), ilda, stride_a,
+                B.get(), ildb, stride_b,
+                &beta, C.get(), ildc, stride_c,
+                batch_count
+            )
+        );
+    }
+
 }}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
 
 #endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */
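
Note: the two specializations above only narrow arguments and dispatch to cublasHgemmStridedBatched / cublasSgemmStridedBatched. As a reference for the contract being forwarded, here is a plain-C++ sketch of column-major strided batched GEMM for the no-transpose case (illustrative only, not part of the patch; the function name is made up). Matrix i of each stack lives at base + i * stride, and element (r, c) of a column-major matrix with leading dimension ld sits at ptr[c * ld + r]:

// Reference semantics only; cuBLAS performs this computation on the GPU.
#include <cstddef>

template <class T>
void ref_gemm_strided_batched(std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
                              T alpha, const T* A, std::size_t lda, std::size_t strideA,
                              const T* B, std::size_t ldb, std::size_t strideB,
                              T beta, T* C, std::size_t ldc, std::size_t strideC,
                              std::size_t batchCount)
{
    for (std::size_t i = 0; i < batchCount; i++)
    {
        const T* Ai = A + i * strideA;  // i-th matrix of each stack
        const T* Bi = B + i * strideB;
        T* Ci = C + i * strideC;

        // C_i = alpha * A_i * B_i + beta * C_i, all matrices column-major.
        // (cuBLAS treats beta == 0 as a plain overwrite; this sketch assumes
        // C is initialized.)
        for (std::size_t c = 0; c < cols_c; c++)
            for (std::size_t r = 0; r < rows_c; r++)
            {
                T sum = T(0);
                for (std::size_t k = 0; k < common_dim; k++)
                    sum += Ai[k * lda + r] * Bi[c * ldb + k];
                Ci[c * ldc + r] = alpha * sum + beta * Ci[c * ldc + r];
            }
    }
}

The wrapper's CV_Assert(stride_c >= irows_c * icols_c) guarantees that the output matrices cannot overlap; the input strides carry no such requirement.
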
diff --git a/modules/dnn/src/cuda4dnn/csl/tensor.hpp b/modules/dnn/src/cuda4dnn/csl/tensor.hpp
index 6e997ab0eb..5a1286de99 100644
--- a/modules/dnn/src/cuda4dnn/csl/tensor.hpp
+++ b/modules/dnn/src/cuda4dnn/csl/tensor.hpp
@@ -369,6 +369,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
             shape.erase(std::begin(shape) + axis);
         }
 
+        /** @brief squeezes the tensor
+         *
+         * removes leading singleton axes until the tensor's rank is equal to the requested rank
+         *
+         * Pre-conditions:
+         * - the tensor must be non-empty
+         * - the tensor's rank must be at least two
+         * - the tensor's rank must be at least the requested rank
+         * - the tensor must be squeezable up to the requested rank
+         *
+         * Exception Guarantee: Strong
+         */
+        void squeeze_to(int r) {
+            CV_Assert(!empty());
+            CV_Assert(rank() >= r);
+            CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
+            std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
+            shape.resize(r);
+        }
+
         /** @brief unsqueezes the tensor
          *
          * adds an axis of unit size before the specified axis
@@ -665,6 +685,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
             shape.erase(std::begin(shape) + axis);
         }
 
+        /** @brief squeezes the tensor
+         *
+         * removes leading singleton axes until the tensor's rank is equal to the requested rank
+         *
+         * Pre-conditions:
+         * - the tensor must be non-empty
+         * - the tensor's rank must be at least two
+         * - the tensor's rank must be at least the requested rank
+         * - the tensor must be squeezable up to the requested rank
+         *
+         * Exception Guarantee: Strong
+         */
+        void squeeze_to(int r) {
+            CV_Assert(!empty());
+            CV_Assert(rank() >= r);
+            CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
+            std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
+            shape.resize(r);
+        }
+
         /** @brief unsqueezes the tensor
          *
         * adds an axis of unit size before the specified axis
@@ -1010,6 +1050,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
             shape.erase(std::begin(shape) + axis);
         }
 
+        /** @brief squeezes the tensor
+         *
+         * removes leading singleton axes until the tensor's rank is equal to the requested rank
+         *
+         * Pre-conditions:
+         * - the tensor must be non-empty
+         * - the tensor's rank must be at least two
+         * - the tensor's rank must be at least the requested rank
+         * - the tensor must be squeezable up to the requested rank
+         *
+         * Exception Guarantee: Strong
+         */
+        void squeeze_to(int r) {
+            CV_Assert(!empty());
+            CV_Assert(rank() >= r);
+            CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
+            std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
+            shape.resize(r);
+        }
+
         /** @brief unsqueezes the tensor
          *
         * adds an axis of unit size before the specified axis
diff --git a/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp b/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp
index aeddaf353b..4ee0e8ab77 100644
--- a/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp
+++ b/modules/dnn/src/cuda4dnn/csl/tensor_ops.hpp
@@ -44,6 +44,30 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
         memcpy<T>(dest.get(), src.get(), dest.size(), stream);
     }
 
+    namespace detail {
+        template <class T>
+        void assertGEMMCompatiblity(const TensorSpan<T>& result, bool transa, const TensorView<T>& A, bool transb, const TensorView<T>& B) {
+            /* check dimension requirements for matrix multiplication */
+            if (!transa && !transb) {
+                CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
+                CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-2));
+                CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
+            } else if (!transa && transb) {
+                CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
+                CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-1));
+                CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
+            } else if (transa && !transb) {
+                CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
+                CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-2));
+                CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
+            } else {
+                CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
+                CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1));
+                CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
+            }
+        }
+    }
+
     /** @brief performs generalized matrix-multiplication
      *
      * Pre-conditions:
@@ -54,29 +78,10 @@
      */
     template <class T> inline
    void gemm(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
-        /* matrix operations can be performed only on rank two or less tensors */
-        CV_Assert(get_effective_rank(A) <= 2 &&
-                  get_effective_rank(B) <= 2 &&
-                  get_effective_rank(result) <= 2);
-
-        /* check dimension requirements for matrix multiplication */
-        if (!transa && !transb) {
-            CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
-            CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-2));
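
For intuition on the batched branch that follows, note that MatMulOp::forward (in the new matmul.hpp below) pairs reshape(b, m, n) with the new squeeze_to(3): the reshape folds all leading batch axes into one while the csl tensors keep their rank by padding with leading unit axes (which is why squeeze_to exists), and squeeze_to then lowers the rank to the exact (batch, rows, cols) layout gemmStridedBatched asserts. A shape-only sketch of that bookkeeping (the Shape alias and helper are illustrative re-implementations, not the csl API):

// Shape bookkeeping only; no device code involved.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<std::size_t>;

// Mirrors the new squeeze_to: keep the trailing r axes, which is legal only
// if every axis in front of them is a singleton.
Shape squeeze_to(const Shape& shape, int r)
{
    assert((int)shape.size() >= r);
    for (std::size_t i = 0; i + r < shape.size(); i++)
        assert(shape[i] == 1);
    return Shape(shape.end() - r, shape.end());
}

int main()
{
    // A 4D MatMul of (2, 3, 4, 5) x (2, 3, 5, 6): m = 4, n = 5, k = 6 and
    // the batch count is everything left over, b = (2*3*4*5) / m / n = 6.
    std::size_t m = 4, n = 5;
    std::size_t b = (2 * 3 * 4 * 5) / m / n;
    std::cout << "batch = " << b << '\n';                     // batch = 6

    // reshape(b, m, n) folds the batch axes but preserves rank by padding:
    // (2, 3, 4, 5) -> (1, 6, 4, 5). squeeze_to(3) then drops the leading
    // singletons, reaching the (batch, rows, cols) form: (6, 4, 5).
    Shape s = squeeze_to({1, b, m, n}, 3);
    std::cout << s[0] << ' ' << s[1] << ' ' << s[2] << '\n';  // 6 4 5
}
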
-            CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
-        } else if (!transa && transb) {
-            CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
-            CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-1));
-            CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
-        } else if (transa && !transb) {
-            CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
-            CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-2));
-            CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
-        } else {
-            CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
-            CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1));
-            CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
-        }
+        /* matrix operations can be performed only on tensors with rank two or below */
+        CV_Assert(get_effective_rank(A) <= 2);
+        CV_Assert(get_effective_rank(B) <= 2);
+        CV_Assert(get_effective_rank(result) <= 2);
 
         const auto result_nr = result.get_axis_size(-2);
         const auto result_nc = result.get_axis_size(-1);
@@ -84,6 +89,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
         const auto A_nc = A.get_axis_size(-1);
         const auto B_nc = B.get_axis_size(-1);
 
+        detail::assertGEMMCompatiblity(result, transa, A, transb, B);
+
         /* tensors are stored in row-major but cublas::gemm operates on column-major matrices
          * a row-major matrix when read as column-major matrix gives the transpose of the intended matrix
          *
@@ -103,6 +110,47 @@
                     beta, result.get(), result_nc);
     }
 
+    /** @brief performs generalized matrix-multiplication for a strided batch of matrices
+     *
+     * Pre-conditions:
+     * - A, B and C must be rank three tensors with dimensions (batch, rows, cols)
+     * - the last two axes of \p A and \p B must meet the mathematical requirements for matrix multiplication
+     * - \p result must be large enough to hold the result and the matrices must not overlap in memory
+     * - the batch dimension must be the same in \p A, \p B and \p result
+     *
+     * Exception Guarantee: Basic
+     */
+    template <class T> inline
+    void gemmStridedBatched(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
+        CV_Assert(A.rank() == 3);
+        CV_Assert(B.rank() == 3);
+        CV_Assert(result.rank() == 3);
+
+        const auto batch_size = result.get_axis_size(0);
+        CV_Assert(batch_size == A.get_axis_size(0));
+        CV_Assert(batch_size == B.get_axis_size(0));
+
+        detail::assertGEMMCompatiblity(result, transa, A, transb, B);
+
+        const auto result_nr = result.get_axis_size(-2);
+        const auto result_nc = result.get_axis_size(-1);
+        const auto common_dim = A.get_axis_size(transa ? -2 : -1);
+        const auto A_nc = A.get_axis_size(-1);
+        const auto B_nc = B.get_axis_size(-1);
+
+        std::size_t strideA = (A.size() / batch_size),
+                    strideB = (B.size() / batch_size),
+                    strideC = (result.size() / batch_size);
+
+        cublas::gemmStridedBatched<T>(handle,
+            transb, transa,
+            result_nc, result_nr, common_dim,
+            alpha, B.get(), B_nc, strideB,
+            A.get(), A_nc, strideA,
+            beta, result.get(), result_nc, strideC,
+            batch_size);
+    }
+
     /** @brief performs element-wise addition with broadcasting
      *
      * Pre-conditions:
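
A note on the operand swap in gemm and gemmStridedBatched above (B passed before A, columns before rows): a row-major buffer reinterpreted as column-major is the transpose of the intended matrix, so row-major C = AB is obtained by asking the column-major routine for C^T = B^T A^T. A self-contained CPU demonstration of the trick (illustrative only; gemm_colmajor is a hypothetical stand-in for cuBLAS):

// Demonstration only: row-major C = A * B via a column-major GEMM.
#include <cstddef>
#include <iostream>
#include <vector>

// Naive column-major GEMM, standing in for cuBLAS:
// C (m x n) = A (m x k) * B (k x n); element (r, c) is at ptr[c * ld + r].
void gemm_colmajor(std::size_t m, std::size_t n, std::size_t k,
                   const float* A, std::size_t lda,
                   const float* B, std::size_t ldb,
                   float* C, std::size_t ldc)
{
    for (std::size_t c = 0; c < n; c++)
        for (std::size_t r = 0; r < m; r++)
        {
            float sum = 0.f;
            for (std::size_t i = 0; i < k; i++)
                sum += A[i * lda + r] * B[c * ldb + i];
            C[c * ldc + r] = sum;
        }
}

int main()
{
    // Row-major A (2x3), B (3x2); expected row-major C = A * B (2x2).
    std::vector<float> A = { 1, 2, 3,
                             4, 5, 6 };
    std::vector<float> B = { 7,  8,
                             9, 10,
                            11, 12 };
    std::vector<float> C(4);

    // Swap the operands and dimensions, exactly like tensor_ops::gemm does:
    // the column-major routine computes C^T = B^T * A^T, and C^T stored
    // column-major occupies memory identically to C stored row-major.
    gemm_colmajor(/*m=*/2, /*n=*/2, /*k=*/3, B.data(), 2, A.data(), 3, C.data(), 2);

    std::cout << C[0] << ' ' << C[1] << '\n'    // 58 64
              << C[2] << ' ' << C[3] << '\n';   // 139 154
}
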
diff --git a/modules/dnn/src/cuda4dnn/primitives/matmul.hpp b/modules/dnn/src/cuda4dnn/primitives/matmul.hpp
new file mode 100644
index 0000000000..e29036d5f4
--- /dev/null
+++ b/modules/dnn/src/cuda4dnn/primitives/matmul.hpp
@@ -0,0 +1,95 @@
+// This file is part of OpenCV project.
+// It is subject to the license terms in the LICENSE file found in the top-level directory
+// of this distribution and at http://opencv.org/license.html.
+
+#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
+#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
+
+#include "../../op_cuda.hpp"
+
+#include "../csl/stream.hpp"
+#include "../csl/cublas.hpp"
+#include "../csl/tensor.hpp"
+#include "../csl/tensor_ops.hpp"
+
+#include <opencv2/core.hpp>
+
+#include <utility>
+
+namespace cv { namespace dnn { namespace cuda4dnn {
+
+    template <class T>
+    class MatMulOp final : public CUDABackendNode {
+    public:
+        using wrapper_type = GetCUDABackendWrapperType<T>;
+
+        MatMulOp(csl::Stream stream_, csl::cublas::Handle handle)
+            : stream(std::move(stream_)), cublasHandle(std::move(handle))
+        {
+        }
+
+        void forward(
+            const std::vector<cv::Ptr<BackendWrapper>>& inputs,
+            const std::vector<cv::Ptr<BackendWrapper>>& outputs,
+            csl::Workspace& workspace) override
+        {
+            CV_Assert(inputs.size() == 2 && outputs.size() == 1);
+
+            auto input1_wrapper = inputs[0].dynamicCast<wrapper_type>();
+            auto input1 = input1_wrapper->getView();
+
+            auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
+            auto input2 = input2_wrapper->getView();
+
+            auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
+            auto output = output_wrapper->getSpan();
+
+            auto rank = output.rank();
+            CV_Assert(rank == input1.rank());
+            CV_Assert(rank == input2.rank());
+            CV_Assert(rank >= 2); // 1D MatMul not supported
+
+            for (int i = 0; i < rank - 2; i++)
+            {
+                // broadcasting not supported
+                auto size = output.get_axis_size(i);
+                CV_Assert(input1.get_axis_size(i) == size);
+                CV_Assert(input2.get_axis_size(i) == size);
+            }
+
+            auto m = input1.get_axis_size(-2);
+            auto n = input1.get_axis_size(-1);
+            auto k = input2.get_axis_size(-1);
+            auto b = input1.size() / m / n;
+            CV_Assert(input2.get_axis_size(-2) == n);
+            CV_Assert(output.get_axis_size(-2) == m);
+            CV_Assert(output.get_axis_size(-1) == k);
+
+            if (get_effective_rank(output) <= 2)
+            {
+                CV_Assert(b == 1);
+                CV_Assert(get_effective_rank(input1) <= 2);
+                CV_Assert(get_effective_rank(input2) <= 2);
+                csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+            }
+            else
+            {
+                CV_Assert(rank >= 3);
+                input1.reshape(b, m, n);
+                input2.reshape(b, n, k);
+                output.reshape(b, m, k);
+                input1.squeeze_to(3);
+                input2.squeeze_to(3);
+                output.squeeze_to(3);
+                csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
+            }
+        }
+
+    private:
+        csl::Stream stream;
+        csl::cublas::Handle cublasHandle;
+    };
+
+}}} /* namespace cv::dnn::cuda4dnn */
+
+#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP */
diff --git a/modules/dnn/src/layers/fully_connected_layer.cpp b/modules/dnn/src/layers/fully_connected_layer.cpp
index 709420c3ca..d9c1fa65c1 100644
--- a/modules/dnn/src/layers/fully_connected_layer.cpp
+++ b/modules/dnn/src/layers/fully_connected_layer.cpp
@@ -55,6 +55,7 @@ using namespace cv::dnn::ocl4dnn;
 #endif
 
 #ifdef HAVE_CUDA
+#include "../cuda4dnn/primitives/matmul.hpp"
 #include "../cuda4dnn/primitives/inner_product.hpp"
 using namespace cv::dnn::cuda4dnn;
 #endif
@@ -523,10 +524,14 @@ public:
     {
         auto context = reinterpret_cast<csl::CSLContext*>(context_);
 
+        if (weightsMat.empty())
+        {
+            CV_Assert(!bias);
+            return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle));
+        }
+
         auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
-
         auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
-
         auto biasMat_ = bias ? biasMat : Mat();
         return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
     }
diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp
index 81ea1dcdd0..22a504df69 100644
--- a/modules/dnn/test/test_onnx_importer.cpp
+++ b/modules/dnn/test/test_onnx_importer.cpp
@@ -485,8 +485,6 @@ TEST_P(Test_ONNX_layers, MatMul)
 {
     if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
         applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
-    if (backend == DNN_BACKEND_CUDA)
-        applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); // not supported
 
     testONNXModels("matmul_2d");
     testONNXModels("matmul_3d");
@@ -735,8 +733,6 @@ TEST_P(Test_ONNX_layers, MatmulWithTwoInputs)
 #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2020040000)
     applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
 #endif
-    if (backend == DNN_BACKEND_CUDA)
-        applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
 
     testONNXModels("matmul_with_two_inputs");
 }
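
With the skip tags removed, the MatMul tests above now exercise the new primitive through the public API. A minimal usage sketch of the path these tests take, assuming an OpenCV build with CUDA support (the model file name is hypothetical):

#include <opencv2/dnn.hpp>
#include <iostream>

int main()
{
    cv::dnn::Net net = cv::dnn::readNetFromONNX("matmul_3d.onnx");  // hypothetical model
    net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
    net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);

    // A batch of two 3x4 matrices; with this patch the 3D case should map
    // onto MatMulOp::forward and hence onto gemmStridedBatched.
    cv::Mat input({2, 3, 4}, CV_32F, cv::Scalar(1.0f));
    net.setInput(input);

    cv::Mat out = net.forward();
    std::cout << "output dims: " << out.dims << '\n';
}
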