Merge pull request #24812 from Abdurrahheem:ash/einsum_bachedGemm
Replace iterative batched Matrix Multiply. #24812

This PR replaces the iterative batched matrix multiplication with `FastGemmBatch` in the Einsum layer.

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There is an accuracy test, a performance test and test data in the opencv_extra repository, if applicable. The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent 1e190b3094
commit c923c59833
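In outline, the change below swaps a per-slice GEMM loop for a single batched call. The following is a minimal reference sketch of that computation written against OpenCV's public `cv::gemm`, standing in for the internal `fastGemm`/`fastGemmBatch` kernels the patch actually calls; the function name, shapes, and contiguity assumptions here are illustrative, not part of the patch.

```cpp
// Reference sketch: batched MatMul of {batches, M, K} x {batches, K, N} -> {batches, M, N}.
// NOTE: the PR uses the DNN module's internal fastGemm/fastGemmBatch kernels;
// cv::gemm stands in for them here, and all names/shapes are illustrative.
#include <opencv2/core.hpp>

cv::Mat batchedMatMulRef(const cv::Mat& A, const cv::Mat& B)  // A: {batches, M, K}, B: {batches, K, N}, CV_32F
{
    CV_Assert(A.dims == 3 && B.dims == 3 && A.size[0] == B.size[0] && A.size[2] == B.size[1]);
    CV_Assert(A.isContinuous() && B.isContinuous());
    const int batches = A.size[0], M = A.size[1], K = A.size[2], N = B.size[2];
    cv::Mat C({batches, M, N}, CV_32F);
    for (int i = 0; i < batches; ++i)   // the old code ran one GEMM per slice like this;
    {                                   // the new code issues one fastGemmBatch call instead
        cv::Mat a(M, K, CV_32F, (void*)A.ptr<float>(i));  // 2-D headers over 3-D data, no copies
        cv::Mat b(K, N, CV_32F, (void*)B.ptr<float>(i));
        cv::Mat c(M, N, CV_32F, C.ptr<float>(i));
        cv::gemm(a, b, 1.0, cv::noArray(), 0.0, c);       // c = a * b
    }
    return C;
}
```

Collapsing this loop into one batched kernel invocation is exactly the shape of the change in `batchwiseMatMul` below.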
@@ -385,7 +385,7 @@ void fastGemmBatch(bool trans_a, bool trans_b,
     const auto shape_b = shape(B);
     const auto shape_c = shape(C);
     CV_CheckGE(shape_a.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: A must be n-dimensional (n >= 2)");
-    CV_CheckEQ(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");
+    CV_CheckGE(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");

     const float *a = A.ptr<const float>();
     const float *b = B.ptr<const float>();
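The hunk above loosens the rank check on B from `CV_CheckEQ` to `CV_CheckGE`, which matches the message text ("n >= 2") now that B may be a batched, higher-rank tensor. A tiny self-contained illustration of the macro semantics, with a hypothetical shape:

```cpp
// CV_CheckGE(v1, v2, msg) asserts v1 >= v2; CV_CheckEQ required v1 == v2.
// With the batched path, a 3-D B (e.g. {batches, K, N}) must pass the check.
#include <opencv2/core.hpp>
#include <vector>

int main()
{
    std::vector<int> shape_b = {4, 16, 8};  // hypothetical {batches, K, N}
    // Passes: 3 >= 2. The old CV_CheckEQ(..., 2, ...) would have thrown here.
    CV_CheckGE(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");
    return 0;
}
```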
@@ -1299,7 +1299,6 @@ Mat LayerEinsumImpl::batchwiseMatMul(
     const Mat& input2,
     const MatShape& input2ShapeOverride)
 {
-
     // Sanity checks before the actual MatMul
     CV_CheckType(input1.type(), input2.type(), "Data types of the inputs must match for MatMul");
     CV_CheckEQ(input1ShapeOverride.size(), (size_t) 3, "Only 1 batch dimension is allowed for MatMul");
@@ -1312,61 +1311,22 @@ Mat LayerEinsumImpl::batchwiseMatMul(
     int K = input1ShapeOverride[2];
     int N = input2ShapeOverride[2];

-    std::vector<Mat> output;
-    Mat reshapedInput1 = input1;
-    Mat reshapedInput2 = input2;
-
+    Mat output;
     if (batches > 1)
     {
+        Mat reshapedInput1 = input1;
+        Mat reshapedInput2 = input2;
+
+        // create tmpout with type like input1
+        output = Mat({batches, M, N}, input1.type());

-        // input1 should of size MxK
-        // check if input1 needs reshape, if need reshape
-        if (input1.size[0] != M || input1.size[1] != K)
-        {
-            int shape[] = {batches, M, K};
-            reshapedInput1 = input1.reshape(1, 3, shape);
-        }
-
-        // input2 should be of size KxN
-        // check if input2 needs reshape, if needs reshape
-        if (input2.size[0] != K || input2.size[1] != N)
-        {
-            int shape[] = {batches, K, N};
-            reshapedInput2 = input2.reshape(1, 3, shape);
-        }
-
-        for (size_t i=0; i < batches; i++)
-        {
-            std::vector<Range> ranges1 = {cv::Range(i, i+1)};
-            for (int j = 1; j < reshapedInput1.dims; j++)
-                ranges1.emplace_back(cv::Range::all());
-
-            Mat part1 = reshapedInput1(ranges1);
-            int shape[] = {M, K};
-            part1 = part1.reshape(1, sizeof(shape)/sizeof(shape[0]), shape);
-
-            std::vector<Range> ranges2 = {cv::Range(i, i+1)};
-            for (int j = 1; j < reshapedInput2.dims; j++)
-                ranges2.emplace_back(cv::Range::all());
-
-            Mat part2 = reshapedInput2(ranges2);
-            int shape2[] = {K, N};
-            part2 = part2.reshape(1, sizeof(shape2)/sizeof(shape2[0]), shape2);
-
-            Mat tmp_output(M, N, part1.type());
-            fastGemm(false, false, 1.0, part1, part2, 0.0, tmp_output, opt);
-            int newShape[] = {1, M, N};
-            tmp_output = tmp_output.reshape(1, sizeof(newShape)/sizeof(newShape[0]), newShape);
-
-            output.emplace_back(tmp_output);
-        }
+        reshapedInput2 = reshapedInput2.reshape(1, input2ShapeOverride);
+        reshapedInput1 = reshapedInput1.reshape(1, input1ShapeOverride);
+
+        fastGemmBatch(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
     } else {

         Mat reshapedInput1 = input1;
         Mat reshapedInput2 = input2;

         // input1 should of size MxK
         // check if input1 needs reshape, if need reshape
         if (input1.dims > 2 || input1.size[0] != M || input1.size[1] != K)
         {
             int shape[] = {M, K};
@@ -1374,30 +1334,18 @@ Mat LayerEinsumImpl::batchwiseMatMul(
         }

         // input2 should be of size KxN
         // check if input2 needs reshape, if needs reshape
         if (input2.dims > 2 || input2.size[0] != K || input2.size[1] != N)
         {
             int shape2[] = {K, N};
             reshapedInput2 = input2.reshape(1, 2, shape2);
         }

-        Mat tmp_output(M, N, reshapedInput1.type());
-        fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, tmp_output, opt);
-
-        int newShape[] = {1, M, N};
-        tmp_output = tmp_output.reshape(1, sizeof(newShape)/sizeof(newShape[0]), newShape);
-        output.emplace_back(tmp_output);
+        output = Mat(M, N, reshapedInput1.type());
+        fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
+
+        output = output.reshape(1, {1, M, N});
     }

-    int outputDim[] = {static_cast<int>(output.size()), M, N};
-    Mat output_buffer = Mat::zeros(3, outputDim, CV_32F);
-
-    for (size_t i = 0; i < output.size(); i++) {
-        Mat output_slice = output_buffer.row(i);
-        output[i].copyTo(output_slice);
-    }
-    return output_buffer;
+    return output;
 };
 Ptr<EinsumLayer> EinsumLayer::create(const LayerParams& params)
 {
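As an aside on the reshape calls in the new code (`reshapedInput1.reshape(1, input1ShapeOverride)` and `output.reshape(1, {1, M, N})`): `Mat::reshape` with a vector of sizes only rewrites the Mat header, so building the batched 3-D views for `fastGemmBatch` copies no data. A standalone sketch with made-up sizes:

```cpp
// Standalone illustration of the reshape used above: a flat (batches*M) x K
// matrix becomes a {batches, M, K} header without copying data. Sizes are
// made up for the example.
#include <opencv2/core.hpp>
#include <iostream>

int main()
{
    const int batches = 2, M = 3, K = 4;
    cv::Mat flat(batches * M, K, CV_32F, cv::Scalar(1.0f));
    cv::Mat shaped = flat.reshape(1, {batches, M, K});  // same form as reshape(1, input1ShapeOverride)
    std::cout << shaped.dims << "D: " << shaped.size[0] << "x"
              << shaped.size[1] << "x" << shaped.size[2] << std::endl;  // prints "3D: 2x3x4"
    return 0;
}
```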