mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 11:40:44 +08:00
add MatMulOp
This commit is contained in:
parent
7de627c504
commit
32df5faa25
@ -247,6 +247,122 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl { namespace cu
|
||||
);
|
||||
}
|
||||
|
||||
/** @brief Strided batched GEMM for colummn-major matrices
|
||||
*
|
||||
* \f$ C_i = \alpha A_i B_i + \beta C_i \f$ for a stack of matrices A, B and C indexed by i
|
||||
*
|
||||
* @tparam T matrix element type (must be `half` or `float`)
|
||||
*
|
||||
* @param handle valid cuBLAS Handle
|
||||
* @param transa use transposed matrix of A_i for computation
|
||||
* @param transb use transposed matrix of B_i for computation
|
||||
* @param rows_c number of rows in C_i
|
||||
* @param cols_c number of columns in C_i
|
||||
* @param common_dim common dimension of A_i (or trans A_i) and B_i (or trans B_i)
|
||||
* @param alpha scale factor for A_i B_i
|
||||
* @param[in] A pointer to stack of column-major matrices A in device memory
|
||||
* @param lda leading dimension of matrix A_i
|
||||
* @param strideA stride between matrices in A
|
||||
* @param[in] B pointer to stack of column-major matrices B in device memory
|
||||
* @param ldb leading dimension of matrix B_i
|
||||
* @param strideB stride between matrices in B
|
||||
* @param beta scale factor for C_i
|
||||
* @param[in,out] C pointer to stack of column-major matrices C in device memory
|
||||
* @param ldc leading dimension of matrix C_i
|
||||
* @param strideC stride between matrices in C
|
||||
* @param batchCount number of matrices in the batch
|
||||
*
|
||||
* Exception Guarantee: Basic
|
||||
*/
|
||||
template <class T>
|
||||
void gemmStridedBatched(const Handle& handle,
|
||||
bool transa, bool transb,
|
||||
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
|
||||
T alpha, const DevicePtr<const T> A, std::size_t lda, std::size_t strideA,
|
||||
const DevicePtr<const T> B, std::size_t ldb, std::size_t strideB,
|
||||
T beta, const DevicePtr<T> C, std::size_t ldc, std::size_t strideC,
|
||||
std::size_t batchCount);
|
||||
|
||||
template <> inline
|
||||
void gemmStridedBatched<half>(const Handle& handle,
|
||||
bool transa, bool transb,
|
||||
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
|
||||
half alpha, const DevicePtr<const half> A, std::size_t lda, std::size_t strideA,
|
||||
const DevicePtr<const half> B, std::size_t ldb, std::size_t strideB,
|
||||
half beta, const DevicePtr<half> C, std::size_t ldc, std::size_t strideC,
|
||||
std::size_t batchCount)
|
||||
{
|
||||
CV_Assert(handle);
|
||||
|
||||
const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
const auto irows_c = static_cast<int>(rows_c),
|
||||
icols_c = static_cast<int>(cols_c),
|
||||
icommon_dim = static_cast<int>(common_dim),
|
||||
ilda = static_cast<int>(lda),
|
||||
ildb = static_cast<int>(ldb),
|
||||
ildc = static_cast<int>(ldc);
|
||||
|
||||
const auto batch_count = static_cast<int>(batchCount);
|
||||
const auto stride_a = static_cast<long long int>(strideA),
|
||||
stride_b = static_cast<long long int>(strideB),
|
||||
stride_c = static_cast<long long int>(strideC);
|
||||
|
||||
CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
|
||||
|
||||
CUDA4DNN_CHECK_CUBLAS(
|
||||
cublasHgemmStridedBatched(
|
||||
handle.get(),
|
||||
opa, opb,
|
||||
irows_c, icols_c, icommon_dim,
|
||||
&alpha, A.get(), ilda, stride_a,
|
||||
B.get(), ildb, stride_b,
|
||||
&beta, C.get(), ildc, stride_c,
|
||||
batch_count
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
template <> inline
|
||||
void gemmStridedBatched<float>(const Handle& handle,
|
||||
bool transa, bool transb,
|
||||
std::size_t rows_c, std::size_t cols_c, std::size_t common_dim,
|
||||
float alpha, const DevicePtr<const float> A, std::size_t lda, std::size_t strideA,
|
||||
const DevicePtr<const float> B, std::size_t ldb, std::size_t strideB,
|
||||
float beta, const DevicePtr<float> C, std::size_t ldc, std::size_t strideC,
|
||||
std::size_t batchCount)
|
||||
{
|
||||
CV_Assert(handle);
|
||||
|
||||
const auto opa = transa ? CUBLAS_OP_T : CUBLAS_OP_N,
|
||||
opb = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
|
||||
const auto irows_c = static_cast<int>(rows_c),
|
||||
icols_c = static_cast<int>(cols_c),
|
||||
icommon_dim = static_cast<int>(common_dim),
|
||||
ilda = static_cast<int>(lda),
|
||||
ildb = static_cast<int>(ldb),
|
||||
ildc = static_cast<int>(ldc);
|
||||
|
||||
const auto batch_count = static_cast<int>(batchCount);
|
||||
const auto stride_a = static_cast<long long int>(strideA),
|
||||
stride_b = static_cast<long long int>(strideB),
|
||||
stride_c = static_cast<long long int>(strideC);
|
||||
|
||||
CV_Assert(stride_c >= irows_c * icols_c); // output matrices must not overlap
|
||||
|
||||
CUDA4DNN_CHECK_CUBLAS(
|
||||
cublasSgemmStridedBatched(
|
||||
handle.get(),
|
||||
opa, opb,
|
||||
irows_c, icols_c, icommon_dim,
|
||||
&alpha, A.get(), ilda, stride_a,
|
||||
B.get(), ildb, stride_b,
|
||||
&beta, C.get(), ildc, stride_c,
|
||||
batch_count
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
}}}}} /* namespace cv::dnn::cuda4dnn::csl::cublas */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA4DNN_CSL_CUBLAS_HPP */
|
||||
|
@ -369,6 +369,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
shape.erase(std::begin(shape) + axis);
|
||||
}
|
||||
|
||||
/** @brief squeezes the tensor
|
||||
*
|
||||
* removes leading singleton axes until the tensor's rank is equal to the requested rank
|
||||
*
|
||||
* Pre-conditions:
|
||||
* - the tensor must be non-empty
|
||||
* - the tensor's rank must be at least two
|
||||
* - the tensor's rank must be at least the requested rank
|
||||
* - the tensor must be squeezable up to the requested rank
|
||||
*
|
||||
* Exception Guarantee: Strong
|
||||
*/
|
||||
void squeeze_to(int r) {
|
||||
CV_Assert(!empty());
|
||||
CV_Assert(rank() >= r);
|
||||
CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
|
||||
std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
|
||||
shape.resize(r);
|
||||
}
|
||||
|
||||
/** @brief unsqueezes the tensor
|
||||
*
|
||||
* adds a axis of unit size at the requested before the specified axis
|
||||
@ -665,6 +685,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
shape.erase(std::begin(shape) + axis);
|
||||
}
|
||||
|
||||
/** @brief squeezes the tensor
|
||||
*
|
||||
* removes leading singleton axes until the tensor's rank is equal to the requested rank
|
||||
*
|
||||
* Pre-conditions:
|
||||
* - the tensor must be non-empty
|
||||
* - the tensor's rank must be at least two
|
||||
* - the tensor's rank must be at least the requested rank
|
||||
* - the tensor must be squeezable up to the requested rank
|
||||
*
|
||||
* Exception Guarantee: Strong
|
||||
*/
|
||||
void squeeze_to(int r) {
|
||||
CV_Assert(!empty());
|
||||
CV_Assert(rank() >= r);
|
||||
CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
|
||||
std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
|
||||
shape.resize(r);
|
||||
}
|
||||
|
||||
/** @brief unsqueezes the tensor
|
||||
*
|
||||
* adds a axis of unit size at the requested before the specified axis
|
||||
@ -1010,6 +1050,26 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
shape.erase(std::begin(shape) + axis);
|
||||
}
|
||||
|
||||
/** @brief squeezes the tensor
|
||||
*
|
||||
* removes leading singleton axes until the tensor's rank is equal to the requested rank
|
||||
*
|
||||
* Pre-conditions:
|
||||
* - the tensor must be non-empty
|
||||
* - the tensor's rank must be at least two
|
||||
* - the tensor's rank must be at least the requested rank
|
||||
* - the tensor must be squeezable up to the requested rank
|
||||
*
|
||||
* Exception Guarantee: Strong
|
||||
*/
|
||||
void squeeze_to(int r) {
|
||||
CV_Assert(!empty());
|
||||
CV_Assert(rank() >= r);
|
||||
CV_Assert(std::all_of(std::begin(shape), std::end(shape) - r, [](size_type x){ return x == 1; }));
|
||||
std::copy(std::end(shape) - r, std::end(shape), std::begin(shape));
|
||||
shape.resize(r);
|
||||
}
|
||||
|
||||
/** @brief unsqueezes the tensor
|
||||
*
|
||||
* adds a axis of unit size at the requested before the specified axis
|
||||
|
@ -44,6 +44,30 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
memcpy(dest.get(), src.get(), dest.size(), stream);
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
template <class T>
|
||||
void assertGEMMCompatiblity(const TensorSpan<T>& result, bool transa, const TensorView<T>& A, bool transb, const TensorView<T>& B) {
|
||||
/* check dimension requirements for matrix multiplication */
|
||||
if (!transa && !transb) {
|
||||
CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-2));
|
||||
CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
|
||||
} else if (!transa && transb) {
|
||||
CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-1));
|
||||
CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
|
||||
} else if (transa && !transb) {
|
||||
CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-2));
|
||||
CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
|
||||
} else {
|
||||
CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1));
|
||||
CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** @brief performs generalized matrix-multiplication
|
||||
*
|
||||
* Pre-conditions:
|
||||
@ -54,29 +78,10 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
*/
|
||||
template <class T> inline
|
||||
void gemm(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
|
||||
/* matrix operations can be performed only on rank two or less tensors */
|
||||
CV_Assert(get_effective_rank(A) <= 2 &&
|
||||
get_effective_rank(B) <= 2 &&
|
||||
get_effective_rank(result) <= 2);
|
||||
|
||||
/* check dimension requirements for matrix multiplication */
|
||||
if (!transa && !transb) {
|
||||
CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-2));
|
||||
CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
|
||||
} else if (!transa && transb) {
|
||||
CV_Assert(A.get_axis_size(-2) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-1) == B.get_axis_size(-1));
|
||||
CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
|
||||
} else if (transa && !transb) {
|
||||
CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-2));
|
||||
CV_Assert(B.get_axis_size(-1) == result.get_axis_size(-1));
|
||||
} else {
|
||||
CV_Assert(A.get_axis_size(-1) == result.get_axis_size(-2));
|
||||
CV_Assert(A.get_axis_size(-2) == B.get_axis_size(-1));
|
||||
CV_Assert(B.get_axis_size(-2) == result.get_axis_size(-1));
|
||||
}
|
||||
/* matrix operations can be performed only on tensors with rank two or below */
|
||||
CV_Assert(get_effective_rank(A) <= 2);
|
||||
CV_Assert(get_effective_rank(B) <= 2);
|
||||
CV_Assert(get_effective_rank(result) <= 2);
|
||||
|
||||
const auto result_nr = result.get_axis_size(-2);
|
||||
const auto result_nc = result.get_axis_size(-1);
|
||||
@ -84,6 +89,8 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
const auto A_nc = A.get_axis_size(-1);
|
||||
const auto B_nc = B.get_axis_size(-1);
|
||||
|
||||
detail::assertGEMMCompatiblity(result, transa, A, transb, B);
|
||||
|
||||
/* tensors are stored in row-major but cublas::gemm operates on column-major matrices
|
||||
* a row-major matrix when read as column-major matrix gives the transpose of the intended matrix
|
||||
*
|
||||
@ -103,6 +110,47 @@ namespace cv { namespace dnn { namespace cuda4dnn { namespace csl {
|
||||
beta, result.get(), result_nc);
|
||||
}
|
||||
|
||||
/** @brief performs generalized matrix-multiplication for a strided batch of matrices
|
||||
*
|
||||
* Pre-conditions:
|
||||
* - A, B and C must be rank three tensors with dimensions (batch, rows, cols)
|
||||
* - the last two axes of \p A and \p B must meet the mathematical requirements for matrix multiplication
|
||||
* - \p result must be large enough to hold the result and the matrices must not overlap in memory
|
||||
* - batch dimension should be same in \p A, \p B and \p result
|
||||
*
|
||||
* Exception Guarantee: Basic
|
||||
*/
|
||||
template <class T> inline
|
||||
void gemmStridedBatched(const cublas::Handle& handle, T beta, TensorSpan<T> result, T alpha, bool transa, TensorView<T> A, bool transb, TensorView<T> B) {
|
||||
CV_Assert(A.rank() == 3);
|
||||
CV_Assert(B.rank() == 3);
|
||||
CV_Assert(result.rank() == 3);
|
||||
|
||||
const auto batch_size = result.get_axis_size(0);
|
||||
CV_Assert(batch_size == A.get_axis_size(0));
|
||||
CV_Assert(batch_size == B.get_axis_size(0));
|
||||
|
||||
detail::assertGEMMCompatiblity(result, transa, A, transb, B);
|
||||
|
||||
const auto result_nr = result.get_axis_size(-2);
|
||||
const auto result_nc = result.get_axis_size(-1);
|
||||
const auto common_dim = A.get_axis_size(transa ? -2 : -1);
|
||||
const auto A_nc = A.get_axis_size(-1);
|
||||
const auto B_nc = B.get_axis_size(-1);
|
||||
|
||||
std::size_t strideA = (A.size() / batch_size),
|
||||
strideB = (B.size() / batch_size),
|
||||
strideC = (result.size() / batch_size);
|
||||
|
||||
cublas::gemmStridedBatched<T>(handle,
|
||||
transb, transa,
|
||||
result_nc, result_nr, common_dim,
|
||||
alpha, B.get(), B_nc, strideB,
|
||||
A.get(), A_nc, strideA,
|
||||
beta, result.get(), result_nc, strideC,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
/** @brief performs element-wise addition with broadcasting
|
||||
*
|
||||
* Pre-conditions:
|
||||
|
95
modules/dnn/src/cuda4dnn/primitives/matmul.hpp
Normal file
95
modules/dnn/src/cuda4dnn/primitives/matmul.hpp
Normal file
@ -0,0 +1,95 @@
|
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
|
||||
#define OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP
|
||||
|
||||
#include "../../op_cuda.hpp"
|
||||
|
||||
#include "../csl/stream.hpp"
|
||||
#include "../csl/cublas.hpp"
|
||||
#include "../csl/tensor.hpp"
|
||||
#include "../csl/tensor_ops.hpp"
|
||||
|
||||
#include <opencv2/core.hpp>
|
||||
|
||||
#include <utility>
|
||||
|
||||
namespace cv { namespace dnn { namespace cuda4dnn {
|
||||
|
||||
template <class T>
|
||||
class MatMulOp final : public CUDABackendNode {
|
||||
public:
|
||||
using wrapper_type = GetCUDABackendWrapperType<T>;
|
||||
|
||||
MatMulOp(csl::Stream stream_, csl::cublas::Handle handle)
|
||||
: stream(std::move(stream_)), cublasHandle(std::move(handle))
|
||||
{
|
||||
}
|
||||
|
||||
void forward(
|
||||
const std::vector<cv::Ptr<BackendWrapper>>& inputs,
|
||||
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
|
||||
csl::Workspace& workspace) override
|
||||
{
|
||||
CV_Assert(inputs.size() == 2 && outputs.size() == 1);
|
||||
|
||||
auto input1_wrapper = inputs[0].dynamicCast<wrapper_type>();
|
||||
auto input1 = input1_wrapper->getView();
|
||||
|
||||
auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
|
||||
auto input2 = input2_wrapper->getView();
|
||||
|
||||
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
|
||||
auto output = output_wrapper->getSpan();
|
||||
|
||||
auto rank = output.rank();
|
||||
CV_Assert(rank == input1.rank());
|
||||
CV_Assert(rank == input2.rank());
|
||||
CV_Assert(rank >= 2); // 1D MatMul not supported
|
||||
|
||||
for (int i = 0; i < rank - 2; i++)
|
||||
{
|
||||
// broadcasting not supported
|
||||
auto size = output.get_axis_size(i);
|
||||
CV_Assert(input1.get_axis_size(i) == size);
|
||||
CV_Assert(input2.get_axis_size(i) == size);
|
||||
}
|
||||
|
||||
auto m = input1.get_axis_size(-2);
|
||||
auto n = input1.get_axis_size(-1);
|
||||
auto k = input2.get_axis_size(-1);
|
||||
auto b = input1.size() / m / n;
|
||||
CV_Assert(input2.get_axis_size(-2) == n);
|
||||
CV_Assert(output.get_axis_size(-2) == m);
|
||||
CV_Assert(output.get_axis_size(-1) == k);
|
||||
|
||||
if (get_effective_rank(output) <= 2)
|
||||
{
|
||||
CV_Assert(b == 1);
|
||||
CV_Assert(get_effective_rank(input1) <= 2);
|
||||
CV_Assert(get_effective_rank(input2) <= 2);
|
||||
csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_Assert(rank >= 3);
|
||||
input1.reshape(b, m, n);
|
||||
input2.reshape(b, n, k);
|
||||
output.reshape(b, m, k);
|
||||
input1.squeeze_to(3);
|
||||
input2.squeeze_to(3);
|
||||
output.squeeze_to(3);
|
||||
csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
csl::Stream stream;
|
||||
csl::cublas::Handle cublasHandle;
|
||||
};
|
||||
|
||||
}}} /* namespace cv::dnn::cuda4dnn */
|
||||
|
||||
#endif /* OPENCV_DNN_SRC_CUDA4DNN_PRIMITIVES_MATMUL_HPP */
|
@ -55,6 +55,7 @@ using namespace cv::dnn::ocl4dnn;
|
||||
#endif
|
||||
|
||||
#ifdef HAVE_CUDA
|
||||
#include "../cuda4dnn/primitives/matmul.hpp"
|
||||
#include "../cuda4dnn/primitives/inner_product.hpp"
|
||||
using namespace cv::dnn::cuda4dnn;
|
||||
#endif
|
||||
@ -523,10 +524,14 @@ public:
|
||||
{
|
||||
auto context = reinterpret_cast<csl::CSLContext*>(context_);
|
||||
|
||||
if (weightsMat.empty())
|
||||
{
|
||||
CV_Assert(!bias);
|
||||
return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle));
|
||||
}
|
||||
|
||||
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
|
||||
|
||||
auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
|
||||
|
||||
auto biasMat_ = bias ? biasMat : Mat();
|
||||
return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
|
||||
}
|
||||
|
@ -485,8 +485,6 @@ TEST_P(Test_ONNX_layers, MatMul)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
if (backend == DNN_BACKEND_CUDA)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA); // not supported
|
||||
|
||||
testONNXModels("matmul_2d");
|
||||
testONNXModels("matmul_3d");
|
||||
@ -735,8 +733,6 @@ TEST_P(Test_ONNX_layers, MatmulWithTwoInputs)
|
||||
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2020040000)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE);
|
||||
#endif
|
||||
if (backend == DNN_BACKEND_CUDA)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_CUDA);
|
||||
testONNXModels("matmul_with_two_inputs");
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user