mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 11:45:30 +08:00
Merge pull request #22828 from WanliZhong:improve_matmul
DNN: make MatMul support 3D or 4D with broadcast
This commit is contained in:
commit
ac6fb17784
@ -23,9 +23,14 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
||||
public:
|
||||
using wrapper_type = GetCUDABackendWrapperType<T>;
|
||||
|
||||
MatMulOp(csl::Stream stream_, csl::cublas::Handle handle)
|
||||
MatMulOp(csl::Stream stream_, csl::cublas::Handle handle, const Mat& constInp)
|
||||
: stream(std::move(stream_)), cublasHandle(std::move(handle))
|
||||
{
|
||||
if (!constInp.empty())
|
||||
{
|
||||
constTensor = csl::makeTensorHeader<T>(constInp);
|
||||
csl::copyMatToTensor<T>(constInp, constTensor, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void forward(
|
||||
@ -33,13 +38,20 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
||||
const std::vector<cv::Ptr<BackendWrapper>>& outputs,
|
||||
csl::Workspace& workspace) override
|
||||
{
|
||||
CV_Assert(inputs.size() == 2 && outputs.size() == 1);
|
||||
CV_Assert((inputs.size() == 2 && constTensor.empty() ||
|
||||
inputs.size() == 1 && !constTensor.empty()) && outputs.size() == 1);
|
||||
|
||||
auto input1_wrapper = inputs[0].dynamicCast<wrapper_type>();
|
||||
auto input1 = input1_wrapper->getView();
|
||||
|
||||
auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
|
||||
auto input2 = input2_wrapper->getView();
|
||||
csl::TensorView<T> input2;
|
||||
if (constTensor.empty())
|
||||
{
|
||||
auto input2_wrapper = inputs[1].dynamicCast<wrapper_type>();
|
||||
input2 = input2_wrapper->getView();
|
||||
}
|
||||
else
|
||||
input2 = csl::TensorView<T>(constTensor);
|
||||
|
||||
auto output_wrapper = outputs[0].dynamicCast<wrapper_type>();
|
||||
auto output = output_wrapper->getSpan();
|
||||
@ -59,9 +71,18 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
||||
|
||||
auto m = input1.get_axis_size(-2);
|
||||
auto n = input1.get_axis_size(-1);
|
||||
auto k = input2.get_axis_size(-1);
|
||||
auto b = input1.size() / m / n;
|
||||
CV_Assert(input2.get_axis_size(-2) == n);
|
||||
int k;
|
||||
if (constTensor.empty())
|
||||
{
|
||||
k = input2.get_axis_size(-1);
|
||||
CV_Assert(input2.get_axis_size(-2) == n);
|
||||
}
|
||||
else
|
||||
{
|
||||
k = input2.get_axis_size(-2);
|
||||
CV_Assert(input2.get_axis_size(-1) == n);
|
||||
}
|
||||
CV_Assert(output.get_axis_size(-2) == m);
|
||||
CV_Assert(output.get_axis_size(-1) == k);
|
||||
|
||||
@ -70,24 +91,28 @@ namespace cv { namespace dnn { namespace cuda4dnn {
|
||||
CV_Assert(b == 1);
|
||||
CV_Assert(get_effective_rank(input1) <= 2);
|
||||
CV_Assert(get_effective_rank(input2) <= 2);
|
||||
csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
|
||||
csl::tensor_ops::gemm<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_Assert(rank >= 3);
|
||||
input1.reshape(b, m, n);
|
||||
input2.reshape(b, n, k);
|
||||
if (constTensor.empty())
|
||||
input2.reshape(b, n, k);
|
||||
else
|
||||
input2.reshape(b, k, n);
|
||||
output.reshape(b, m, k);
|
||||
input1.squeeze_to(3);
|
||||
input2.squeeze_to(3);
|
||||
output.squeeze_to(3);
|
||||
csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, false, input2);
|
||||
csl::tensor_ops::gemmStridedBatched<T>(cublasHandle, 0.0, output, 1.0, false, input1, !constTensor.empty(), input2);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
csl::Stream stream;
|
||||
csl::cublas::Handle cublasHandle;
|
||||
csl::Tensor<T> constTensor;
|
||||
};
|
||||
|
||||
}}} /* namespace cv::dnn::cuda4dnn */
|
||||
|
@ -85,6 +85,7 @@ public:
|
||||
|
||||
bias = params.get<bool>("bias_term", true);
|
||||
axis = params.get<int>("axis", 1);
|
||||
isMatMul = params.get<bool>("is_matmul", false);
|
||||
if (!blobs.empty())
|
||||
{
|
||||
CV_Assert(1 <= blobs.size() && blobs.size() <= 2);
|
||||
@ -94,6 +95,7 @@ public:
|
||||
CV_Assert(blobs[0].dims >= 2 && (size_t)(innerSize * numOutput) == blobs[0].total());
|
||||
CV_Assert(!bias || (blobs.size() == 2 && (size_t)numOutput == blobs[1].total()));
|
||||
|
||||
blobs[0].copyTo(oriMat);
|
||||
weightsMat = blobs[0] = blobs[0].reshape(1, numOutput);
|
||||
int vecsize = weightsMat.cols;
|
||||
if (vecsize % VEC_ALIGN != 0)
|
||||
@ -108,6 +110,8 @@ public:
|
||||
|
||||
if (bias)
|
||||
biasMat = blobs[1] = blobs[1].reshape(1, 1);
|
||||
else if(isMatMul)
|
||||
biasMat = Mat::zeros(1, oriMat.size[oriMat.dims - 2], weightsMat.type());
|
||||
else
|
||||
biasMat = Mat::zeros(1, numOutput, weightsMat.type());
|
||||
}
|
||||
@ -153,7 +157,10 @@ public:
|
||||
CV_Assert(!transA && !transB);
|
||||
CV_CheckEQ(inputsTmp.size(), (size_t)1, "");
|
||||
CV_CheckEQ(blobs[0].dims, 2, "");
|
||||
numOutput = blobs[0].size[0];
|
||||
if(isMatMul)
|
||||
numOutput = oriMat.size[oriMat.dims - 2];
|
||||
else
|
||||
numOutput = blobs[0].size[0];
|
||||
CV_Assert(!bias || (size_t)numOutput == blobs[1].total());
|
||||
cAxis = normalize_axis(axis, inputsTmp[0]);
|
||||
}
|
||||
@ -519,16 +526,40 @@ public:
|
||||
if (!blobs.empty())
|
||||
{
|
||||
CV_Assert(!transA && !transB);
|
||||
int axisCan = normalize_axis(axis, input[0].dims);
|
||||
int outerSize = input[0].total(0, axisCan);
|
||||
|
||||
for (size_t i = 0; i < input.size(); i++)
|
||||
int inp1Dim = input[0].dims;
|
||||
if (isMatMul)
|
||||
{
|
||||
Mat srcMat = input[i].reshape(1, outerSize);
|
||||
Mat dstMat = output[i].reshape(1, outerSize);
|
||||
int matNum = input[0].total(0, inp1Dim - 2);
|
||||
int rowMatMul = oriMat.size[oriMat.dims - 2];
|
||||
Mat srcMatTmp = input[0].reshape(1, matNum);
|
||||
Mat dstMatTmp = output[0].reshape(1, matNum);
|
||||
|
||||
const int nstripes = getNumThreads();
|
||||
FullyConnected::run(srcMat, weightsMat, biasMat, dstMat, activ.get(), nstripes);
|
||||
int outerSize = input[0].size[inp1Dim - 2];
|
||||
int rowStart = -rowMatMul;
|
||||
for (int n = 0; n < matNum; ++n)
|
||||
{
|
||||
Mat srcMat = srcMatTmp.row(n).reshape(1, outerSize);
|
||||
Mat dstMat = dstMatTmp.row(n).reshape(1, outerSize);
|
||||
rowStart = (rowStart + rowMatMul) % weightsMat.rows;
|
||||
Mat weiMat = weightsMat.rowRange(rowStart, rowStart + rowMatMul);
|
||||
|
||||
const int nstripes = getNumThreads();
|
||||
FullyConnected::run(srcMat, weiMat, biasMat, dstMat, activ.get(), nstripes);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
int axisCan = normalize_axis(axis, inp1Dim);
|
||||
int outerSize = input[0].total(0, axisCan);
|
||||
|
||||
for (size_t i = 0; i < input.size(); i++)
|
||||
{
|
||||
Mat srcMat = input[i].reshape(1, outerSize);
|
||||
Mat dstMat = output[i].reshape(1, outerSize);
|
||||
|
||||
const int nstripes = getNumThreads();
|
||||
FullyConnected::run(srcMat, weightsMat, biasMat, dstMat, activ.get(), nstripes);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
@ -579,14 +610,26 @@ public:
|
||||
) override
|
||||
{
|
||||
auto context = reinterpret_cast<csl::CSLContext*>(context_);
|
||||
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
|
||||
|
||||
if (weightsMat.empty())
|
||||
if (weightsMat.empty() || isMatMul)
|
||||
{
|
||||
CV_Assert(!bias);
|
||||
return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle));
|
||||
int inp2Dim;
|
||||
// broadcast is not supported with CUDA
|
||||
if(weightsMat.empty())
|
||||
{
|
||||
auto input_wrapper2 = inputs[1].dynamicCast<CUDABackendWrapper>();
|
||||
inp2Dim = input_wrapper2->getRank();
|
||||
}else
|
||||
inp2Dim = oriMat.dims;
|
||||
|
||||
if(input_wrapper->getRank() == inp2Dim)
|
||||
return make_cuda_node<cuda4dnn::MatMulOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), oriMat);
|
||||
else
|
||||
return Ptr<BackendNode>();
|
||||
}
|
||||
|
||||
auto input_wrapper = inputs[0].dynamicCast<CUDABackendWrapper>();
|
||||
auto flatten_start_axis = normalize_axis(axis, input_wrapper->getRank());
|
||||
auto biasMat_ = bias ? biasMat : Mat();
|
||||
return make_cuda_node<cuda4dnn::InnerProductOp>(preferableTarget, std::move(context->stream), std::move(context->cublas_handle), flatten_start_axis, weightsMat, biasMat_);
|
||||
@ -752,8 +795,9 @@ public:
|
||||
}
|
||||
|
||||
bool bias;
|
||||
Mat weightsMat, biasMat;
|
||||
Mat weightsMat, biasMat, oriMat;
|
||||
bool transA, transB;
|
||||
bool isMatMul = false;
|
||||
Ptr<ActivationLayer> activ;
|
||||
};
|
||||
|
||||
|
@ -2088,30 +2088,21 @@ void ONNXImporter::parseMatMul(LayerParams& layerParams, const opencv_onnx::Node
|
||||
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
|
||||
{
|
||||
Mat blob = getBlob(node_proto, 1);
|
||||
Mat transBlob;
|
||||
secondInpDims = blob.dims;
|
||||
if (secondInpDims == 2)
|
||||
{
|
||||
layerParams.blobs.push_back(blob.t());
|
||||
layerParams.set("num_output", layerParams.blobs[0].size[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
LayerParams constParams;
|
||||
constParams.name = layerParams.name + "/const_1";
|
||||
constParams.type = "Const";
|
||||
constParams.blobs.push_back(blob);
|
||||
|
||||
opencv_onnx::NodeProto tmpProto;
|
||||
tmpProto.add_output(constParams.name);
|
||||
addLayer(constParams, tmpProto);
|
||||
|
||||
node_proto.set_input(1, constParams.name);
|
||||
}
|
||||
}
|
||||
else
|
||||
// create order transposing last 2 dimensions
|
||||
std::vector<int> order(secondInpDims);
|
||||
std::iota(order.begin(), order.end(), 0);
|
||||
std::swap(order[secondInpDims - 2], order[secondInpDims - 1]);
|
||||
transposeND(blob, order, transBlob);
|
||||
layerParams.blobs.push_back(transBlob);
|
||||
int numOutput = layerParams.blobs[0].total(0, secondInpDims - 1);
|
||||
layerParams.set("num_output", numOutput);
|
||||
layerParams.set("is_matmul", true);
|
||||
} else
|
||||
secondInpDims = outShapes[node_proto.input(1)].size();
|
||||
|
||||
layerParams.set("axis", firstInpDims - secondInpDims + 1);
|
||||
layerParams.set("axis", firstInpDims - 1);
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
|
@ -921,6 +921,7 @@ TEST_P(Test_ONNX_layers, MatMul_init)
|
||||
testONNXModels("matmul_4d_init");
|
||||
|
||||
testONNXModels("matmul_init_2");
|
||||
testONNXModels("matmul_init_bcast");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, MatMulAdd)
|
||||
|
Loading…
Reference in New Issue
Block a user