add MatMulOp

This commit is contained in:
YashasSamaga 2021-05-22 01:01:29 +05:30
parent 7de627c504
commit 32df5faa25
6 changed files with 349 additions and 29 deletions

View File

@ -247,6 +247,122 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
); );
} }
/** @brief Strided batched GEMM for colummn-major matrices
*
* \f$ C_i = \alpha A_i B_i + \beta C_i \f$ for a stack of matrices A, B and C indexed by i
*
* @tparam T matrix element type (must be `half` or `float`)
*
* @param handle valid cuBLAS Handle
* @param transa use transposed matrix of A_i for computation
* @param transb use transposed matrix of B_i for computation
* @param rows_c number of rows in C_i
* @param cols_c number of columns in C_i
* @param common_dim common dimension of A_i (or trans A_i) and B_i (or trans B_i)
* @param alpha scale factor for A_i B_i
* @param[in] A pointer to stack of column-major matrices A in device memory
* @param lda leading dimension of matrix A_i
* @param strideA stride between matrices in A
* @param[in] B pointer to stack of column-major matrices B in device memory
* @param ldb leading dimension of matrix B_i
* @param strideB stride between matrices in B
* @param beta scale factor for C_i
* @param[in,out] C pointer to stack of column-major matrices C in device memory
* @param ldc leading dimension of matrix C_i
* @param strideC stride between matrices in C
* @param batchCount number of matrices in the batch
*
* Exception Guarantee: Basic
*/
template <class T>
void gemmStridedBatched(const Handle& handle,
bool transa, bool transb,
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
T alpha, const DevicePtr<const T> A, std::size_t lda, std::size_t strideA,
const DevicePtr<const T> B, std::size_t ldb, std::size_t strideB,
T beta, const DevicePtr<T> C, std::size_t ldc, std::size_t strideC,
std::size_t batchCount);
template <> inline
void gemmStridedBatched<half>(const Handle& handle,
bool transa, bool transb,
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
half alpha, const DevicePtr<const half> A, std::size_t lda, std::size_t strideA,
const DevicePtr<const half> B, std::size_t ldb, std::size_t strideB,
half beta, const DevicePtr<half> C, std::size_t ldc, std::size_t strideC,
std::size_t batchCount)
{
CV_Assert(handle);
const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto irows_c = static_cast<int>(rows_c),
icols_c = static_cast<int>(cols_c),
icommon_dim = static_cast<int>(common_dim),
ilda = static_cast<int>(lda),
ildb = static_cast<int>(ldb),
ildc = static_cast<int>(ldc);
const auto batch_count = static_cast<int>(batchCount);
const auto stride_a = static_cast<long long int>(strideA),
stride_b = static_cast<long long int>(strideB),
stride_c = static_cast<long long int>(strideC);
CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
CUDA4DNN_CHECK_CUBLAS(
cublasHgemmStridedBatched(
handle.get(),
opa, opb,
irows_c, icols_c, icommon_dim,
&alpha, A.get(), ilda, stride_a,
B.get(), ildb, stride_b,
&beta, C.get(), ildc, stride_c,
batch_count
)
);
}
template <> inline
void gemmStridedBatched<float>(const Handle& handle,
bool transa, bool transb,
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
float alpha, const DevicePtr<const float> A, std::size_t lda, std::size_t strideA,
const DevicePtr<const float> B, std::size_t ldb, std::size_t strideB,
float beta, const DevicePtr<float> C, std::size_t ldc, std::size_t strideC,
std::size_t batchCount)
{
CV_Assert(handle);
const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto irows_c = static_cast<int>(rows_c),
icols_c = static_cast<int>(cols_c),
icommon_dim = static_cast<int>(common_dim),
ilda = static_cast<int>(lda),
ildb = static_cast<int>(ldb),
ildc = static_cast<int>(ldc);
const auto batch_count = static_cast<int>(batchCount);
const auto stride_a = static_cast<long long int>(strideA),
stride_b = static_cast<long long int>(strideB),
stride_c = static_cast<long long int>(strideC);
CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
CUDA4DNN_CHECK_CUBLAS(
cublasSgemmStridedBatched(
handle.get(),
opa, opb,
irows_c, icols_c, icommon_dim,
&alpha, A.get(), ilda, stride_a,
B.get(), ildb, stride_b,
&beta, C.get(), ildc, stride_c,
batch_count
)
);
}
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */ }}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */ #endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */

View File

@ -369,6 +369,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
shape.erase(std::begin(shape) + axis); shape.erase(std::begin(shape) + axis);
} }
/** @brief squeezes the tensor
*
* removes leading singleton axes until the tensor's rank is equal to the requested rank
*
* Pre-conditions:
* - the tensor must be non-empty
* - the tensor's rank must be at least two
* - the tensor's rank must be at least the requested rank
* - the tensor must be squeezable up to the requested rank
*
* Exception Guarantee: Strong
*/
void squeeze_to(int r) {
CV_Assert(!empty());
CV_Assert(rank() >= r);
CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
shape.resize(r);
}
/** @brief unsqueezes the tensor /** @brief unsqueezes the tensor
* *
* adds a axis of unit size at the requested before the specified axis * adds a axis of unit size at the requested before the specified axis
@ -665,6 +685,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
shape.erase(std::begin(shape) + axis); shape.erase(std::begin(shape) + axis);
} }
/** @brief squeezes the tensor
*
* removes leading singleton axes until the tensor's rank is equal to the requested rank
*
* Pre-conditions:
* - the tensor must be non-empty
* - the tensor's rank must be at least two
* - the tensor's rank must be at least the requested rank
* - the tensor must be squeezable up to the requested rank
*
* Exception Guarantee: Strong
*/
void squeeze_to(int r) {
CV_Assert(!empty());
CV_Assert(rank() >= r);
CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
shape.resize(r);
}
/** @brief unsqueezes the tensor /** @brief unsqueezes the tensor
* *
* adds a axis of unit size at the requested before the specified axis * adds a axis of unit size at the requested before the specified axis
@ -1010,6 +1050,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
shape.erase(std::begin(shape) + axis); shape.erase(std::begin(shape) + axis);
} }
/** @brief squeezes the tensor
*
* removes leading singleton axes until the tensor's rank is equal to the requested rank
*
* Pre-conditions:
* - the tensor must be non-empty
* - the tensor's rank must be at least two
* - the tensor's rank must be at least the requested rank
* - the tensor must be squeezable up to the requested rank
*
* Exception Guarantee: Strong
*/
void squeeze_to(int r) {
CV_Assert(!empty());
CV_Assert(rank() >= r);
CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
shape.resize(r);
}
/** @brief unsqueezes the tensor /** @brief unsqueezes the tensor
* *
* adds a axis of unit size at the requested before the specified axis * adds a axis of unit size at the requested before the specified axis

View File

@ -44,21 +44,9 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
memcpy(dest.get(), src.get(), dest.size(), stream); memcpy(dest.get(), src.get(), dest.size(), stream);
} }
/** @brief performs generalized matrix-multiplication namespace detail {
* template <class T>
* Pre-conditions: void assertGEMMCompatiblity(const TensorSpan<T>& result, bool transa, const TensorView<T>& A, bool transb, const TensorView<T>& B) {
* - \p A and \p B must meet the mathematical requirements for matrix multiplication
* - \p result must be large enough to hold the result
*
* Exception Guarantee: Basic
*/
template <class T> inline
void gemm(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
/* matrix operations can be performed only on rank two or less tensors */
CV_Assert(get_effective_rank(A) <= 2 &&
get_effective_rank(B) <= 2 &&
get_effective_rank(result) <= 2);
/* check dimension requirements for matrix multiplication */ /* check dimension requirements for matrix multiplication */
if (!transa && !transb) { if (!transa && !transb) {
CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2)); CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
@ -77,6 +65,23 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1)); CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1));
CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1)); CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
} }
}
}
/** @brief performs generalized matrix-multiplication
*
* Pre-conditions:
* - \p A and \p B must meet the mathematical requirements for matrix multiplication
* - \p result must be large enough to hold the result
*
* Exception Guarantee: Basic
*/
template <class T> inline
void gemm(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
/* matrix operations can be performed only on tensors with rank two or below */
CV_Assert(get_effective_rank(A) <= 2);
CV_Assert(get_effective_rank(B) <= 2);
CV_Assert(get_effective_rank(result) <= 2);
const auto result_nr = result.get_axis_size(-2); const auto result_nr = result.get_axis_size(-2);
const auto result_nc = result.get_axis_size(-1); const auto result_nc = result.get_axis_size(-1);
@ -84,6 +89,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
const auto A_nc = A.get_axis_size(-1); const auto A_nc = A.get_axis_size(-1);
const auto B_nc = B.get_axis_size(-1); const auto B_nc = B.get_axis_size(-1);
detail::assertGEMMCompatiblity(result, transa, A, transb, B);
/* tensors are stored in row-major but cublas::gemm operates on column-major matrices /* tensors are stored in row-major but cublas::gemm operates on column-major matrices
* a row-major matrix when read as column-major matrix gives the transpose of the intended matrix * a row-major matrix when read as column-major matrix gives the transpose of the intended matrix
* *
@ -103,6 +110,47 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
beta, result.get(), result_nc); beta, result.get(), result_nc);
} }
/** @brief performs generalized matrix-multiplication for a strided batch of matrices
*
* Pre-conditions:
* - A, B and C must be rank three tensors with dimensions (batch, rows, cols)
* - the last two axes of \p A and \p B must meet the mathematical requirements for matrix multiplication
* - \p result must be large enough to hold the result and the matrices must not overlap in memory
* - batch dimension should be same in \p A, \p B and \p result
*
* Exception Guarantee: Basic
*/
template <class T> inline
void gemmStridedBatched(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
CV_Assert(A.rank() == 3);
CV_Assert(B.rank() == 3);
CV_Assert(result.rank() == 3);
const auto batch_size = result.get_axis_size(0);
CV_Assert(batch_size == A.get_axis_size(0));
CV_Assert(batch_size == B.get_axis_size(0));
detail::assertGEMMCompatiblity(result, transa, A, transb, B);
const auto result_nr = result.get_axis_size(-2);
const auto result_nc = result.get_axis_size(-1);
const auto common_dim = A.get_axis_size(transa ? -2 : -1);
const auto A_nc = A.get_axis_size(-1);
const auto B_nc = B.get_axis_size(-1);
std::size_t strideA = (A.size() / batch_size),
strideB = (B.size() / batch_size),
strideC = (result.size() / batch_size);
cublas::gemmStridedBatched<T>(handle,
transb, transa,
result_nc, result_nr, common_dim,
alpha, B.get(), B_nc, strideB,
A.get(), A_nc, strideA,
beta, result.get(), result_nc, strideC,
batch_size);
}
/** @brief performs element-wise addition with broadcasting /** @brief performs element-wise addition with broadcasting
* *
* Pre-conditions: * Pre-conditions:

View File

@ -0,0 +1,95 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
#include "../../op_cuda.hpp"
#include "../csl/stream.hpp"
#include "../csl/cublas.hpp"
#include "../csl/tensor.hpp"
#include "../csl/tensor_ops.hpp"
#include <opencv2/core.hpp>
#include <utility>
namespace cv { namespace dnn { namespace cuda4dnn {
template <class T>
class MatMulOp final : public CUDABackendNode {
public:
using wrapper_type = GetCUDABackendWrapperType<T>;
MatMulOp(csl::Stream stream_, csl::cublas::Handle handle)
: stream(std::move(stream_)), cublasHandle(std::move(handle))
{
}
void forward(
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
csl::Workspace& workspace) override
{
CV_Assert(inputs.size() == 2 && outputs.size() == 1);
auto input1_wrapper = inputs[0].dynamicCast<wrapper_type>();
auto input1 = input1_wrapper->getView();
auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
auto input2 = input2_wrapper->getView();
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
auto output = output_wrapper->getSpan();
auto rank = output.rank();
CV_Assert(rank == input1.rank());
CV_Assert(rank == input2.rank());
CV_Assert(rank >= 2); // 1D MatMul not supported
for (int i = 0; i < rank - 2; i++)
{
// broadcasting not supported
auto size = output.get_axis_size(i);
CV_Assert(input1.get_axis_size(i) == size);
CV_Assert(input2.get_axis_size(i) == size);
}
auto m = input1.get_axis_size(-2);
auto n = input1.get_axis_size(-1);
auto k = input2.get_axis_size(-1);
auto b = input1.size() / m / n;
CV_Assert(input2.get_axis_size(-2) == n);
CV_Assert(output.get_axis_size(-2) == m);
CV_Assert(output.get_axis_size(-1) == k);
if (get_effective_rank(output) <= 2)
{
CV_Assert(b == 1);
CV_Assert(get_effective_rank(input1) <= 2);
CV_Assert(get_effective_rank(input2) <= 2);
csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
}
else
{
CV_Assert(rank >= 3);
input1.reshape(b, m, n);
input2.reshape(b, n, k);
output.reshape(b, m, k);
input1.squeeze_to(3);
input2.squeeze_to(3);
output.squeeze_to(3);
csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
}
}
private:
csl::Stream stream;
csl::cublas::Handle cublasHandle;
};
}}} /* namespace cv::dnn::cuda4dnn */
#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP */

View File

@ -55,6 +55,7 @@ using namespace cv::dnn::ocl4dnn;
#endif #endif
#ifdef HAVE_CUDA #ifdef HAVE_CUDA
#include "../cuda4dnn/primitives/matmul.hpp"
#include "../cuda4dnn/primitives/inner_product.hpp" #include "../cuda4dnn/primitives/inner_product.hpp"
using namespace cv::dnn::cuda4dnn; using namespace cv::dnn::cuda4dnn;
#endif #endif
@ -523,10 +524,14 @@ public:
{ {
auto context = reinterpret_cast<csl::CSLContext*>(context_); auto context = reinterpret_cast<csl::CSLContext*>(context_);
if (weightsMat.empty())
{
CV_Assert(!bias);
return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle));
}
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>(); auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank()); auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
auto biasMat_ = bias ? biasMat : Mat(); auto biasMat_ = bias ? biasMat : Mat();
return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_); return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
} }

View File

@ -485,8 +485,6 @@ TEST_P(Test_ONNX_layers, MatMul)
{ {
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019) if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER); applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
if (backend == DNN_BACKEND_CUDA)
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); // not supported
testONNXModels("matmul_2d"); testONNXModels("matmul_2d");
testONNXModels("matmul_3d"); testONNXModels("matmul_3d");
@ -735,8 +733,6 @@ TEST_P(Test_ONNX_layers, MatmulWithTwoInputs)
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2020040000) #if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2020040000)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE); applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
#endif #endif
if (backend == DNN_BACKEND_CUDA)
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
testONNXModels("matmul_with_two_inputs"); testONNXModels("matmul_with_two_inputs");
} }