diff --git a/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp b/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp
index a8972aba4e..f8fe2bb40e 100644
--- a/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp
+++ b/modules/dnn/src/layers/cpu_kernels/fast_gemm.cpp
@@ -385,7 +385,7 @@ void fastGemmBatch(bool trans_a, bool trans_b,
     const auto shape_b = shape(B);
     const auto shape_c = shape(C);
     CV_CheckGE(shape_a.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: A must be n-dimensional (n >= 2)");
-    CV_CheckEQ(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");
+    CV_CheckGE(shape_b.size(), static_cast<size_t>(2), "DNN/fastGemmBatch: B must be n-dimensional (n >= 2)");
 
     const float *a = A.ptr<const float>();
     const float *b = B.ptr<const float>();
diff --git a/modules/dnn/src/layers/einsum_layer.cpp b/modules/dnn/src/layers/einsum_layer.cpp
index c7f9aaca06..d5153a5ab7 100644
--- a/modules/dnn/src/layers/einsum_layer.cpp
+++ b/modules/dnn/src/layers/einsum_layer.cpp
@@ -1299,7 +1299,6 @@ Mat LayerEinsumImpl::batchwiseMatMul(
     const Mat& input2,
     const MatShape& input2ShapeOverride)
 {
-    // Sanity checks before the actual MatMul
     CV_CheckType(input1.type(), input2.type(), "Data types of the inputs must match for MatMul");
     CV_CheckEQ(input1ShapeOverride.size(), (size_t) 3, "Only 1 batch dimension is allowed for MatMul");
 
@@ -1312,61 +1311,22 @@ Mat LayerEinsumImpl::batchwiseMatMul(
     int K = input1ShapeOverride[2];
     int N = input2ShapeOverride[2];
 
-    std::vector<Mat> output;
+    Mat reshapedInput1 = input1;
+    Mat reshapedInput2 = input2;
+
+    Mat output;
     if (batches > 1)
     {
-        Mat reshapedInput1 = input1;
-        Mat reshapedInput2 = input2;
+        // create tmpout with type like input1
+        output = Mat({batches, M, N}, input1.type());
 
-        // input1 should of size MxK
-        // check if input1 needs reshape, if need reshape
-        if (input1.size[0] != M || input1.size[1] != K)
-        {
-            int shape[] = {batches, M, K};
-            reshapedInput1 = input1.reshape(1, 3, shape);
-        }
-
-        // input2 should be of size KxN
-        // check if input2 needs reshape, if needs reshape
-        if (input2.size[0] != K || input2.size[1] != N)
-        {
-            int shape[] = {batches, K, N};
-            reshapedInput2 = input2.reshape(1, 3, shape);
-        }
-
-        for (size_t i=0; i < batches; i++)
-        {
-            std::vector<cv::Range> ranges1 = {cv::Range(i, i+1)};
-            for (int j = 1; j < reshapedInput1.dims; j++)
-                ranges1.emplace_back(cv::Range::all());
-
-            Mat part1 = reshapedInput1(ranges1);
-            int shape[] = {M, K};
-            part1 = part1.reshape(1, sizeof(shape)/sizeof(shape[0]), shape);
-
-            std::vector<cv::Range> ranges2 = {cv::Range(i, i+1)};
-            for (int j = 1; j < reshapedInput2.dims; j++)
-                ranges2.emplace_back(cv::Range::all());
-
-            Mat part2 = reshapedInput2(ranges2);
-            int shape2[] = {K, N};
-            part2 = part2.reshape(1, sizeof(shape2)/sizeof(shape2[0]), shape2);
-
-            Mat tmp_output(M, N, part1.type());
-            fastGemm(false, false, 1.0, part1, part2, 0.0, tmp_output, opt);
-            int newShape[] = {1, M, N};
-            tmp_output = tmp_output.reshape(1, sizeof(newShape)/sizeof(newShape[0]), newShape);
-
-            output.emplace_back(tmp_output);
-        }
+        reshapedInput2 = reshapedInput2.reshape(1, input2ShapeOverride);
+        reshapedInput1 = reshapedInput1.reshape(1, input1ShapeOverride);
+        fastGemmBatch(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
     } else {
-        Mat reshapedInput1 = input1;
-        Mat reshapedInput2 = input2;
-
         // input1 should of size MxK
-        // check if input1 needs reshape, if need reshape
         if (input1.dims > 2 || input1.size[0] != M || input1.size[1] != K)
         {
             int shape[] = {M, K};
@@ -1374,30 +1334,18 @@ Mat LayerEinsumImpl::batchwiseMatMul(
         }
 
         // input2 should be of size KxN
-        // check if input2 needs reshape, if needs reshape
         if (input2.dims > 2 || input2.size[0] != K || input2.size[1] != N)
        {
             int shape2[] = {K, N};
             reshapedInput2 = input2.reshape(1, 2, shape2);
         }
 
-        Mat tmp_output(M, N, reshapedInput1.type());
-        fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, tmp_output, opt);
-
-        int newShape[] = {1, M, N};
-        tmp_output = tmp_output.reshape(1, sizeof(newShape)/sizeof(newShape[0]), newShape);
-        output.emplace_back(tmp_output);
+        output = Mat(M, N, reshapedInput1.type());
+        fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
+        output = output.reshape(1, {1, M, N});
     }
-
-    int outputDim[] = {static_cast<int>(output.size()), M, N};
-    Mat output_buffer = Mat::zeros(3, outputDim, CV_32F);
-
-    for (size_t i = 0; i < output.size(); i++) {
-        Mat output_slice = output_buffer.row(i);
-        output[i].copyTo(output_slice);
-    }
-    return output_buffer;
+    return output;
 };
 
 Ptr<EinsumLayer> EinsumLayer::create(const LayerParams& params)