From 6c28d7140a7275c9384f3b391b91b26b6b22445d Mon Sep 17 00:00:00 2001 From: Abduragim Date: Mon, 8 Jan 2024 21:34:47 +0300 Subject: [PATCH] 1d support for einsum --- modules/dnn/src/layers/einsum_layer.cpp | 4 ++-- modules/dnn/test/test_onnx_importer.cpp | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/dnn/src/layers/einsum_layer.cpp b/modules/dnn/src/layers/einsum_layer.cpp index baf4297c0e..3dee2057ab 100644 --- a/modules/dnn/src/layers/einsum_layer.cpp +++ b/modules/dnn/src/layers/einsum_layer.cpp @@ -105,7 +105,7 @@ static Mat batchwiseMatMul( // input1 should of size MxK // check if input1 needs reshape, if need reshape - if (input1.dims > 2 || input1.size[0] != M || input1.size[1] != K) + if (input1.dims > 2 || input1.size[0] != M || (input1.dims > 1 && input1.size[1] != K) || input1.dims == 1) { int shape[] = {static_cast(M), static_cast(K)}; reshapedInput1 = input1.reshape(1, 2, shape); @@ -113,7 +113,7 @@ static Mat batchwiseMatMul( // input2 should be of size KxN // check if input2 needs reshape, if needs reshape - if (input2.dims > 2 || input2.size[0] != K || input2.size[1] != N) + if (input2.dims > 2 || input2.size[0] != K || (input2.dims > 1 && input2.size[1] != N) || input2.dims == 1) { int shape2[] = {static_cast(K), static_cast(N)}; reshapedInput2 = input2.reshape(1, 2, shape2); diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 525da26ed5..c304a897d7 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1452,7 +1452,7 @@ TEST_P(Test_ONNX_layers, LSTM_layout_batch) testONNXModels("lstm_layout_1", npy, 0.005, 0.005, false, false, 3); } -TEST_P(Test_ONNX_layers, DISABLED_Einsum_1D) +TEST_P(Test_ONNX_layers, Einsum_1D) { testONNXModels("einsum_1d", npy, 0, 0, false, false, 2); } @@ -1482,12 +1482,12 @@ TEST_P(Test_ONNX_layers, Einsum_5D) testONNXModels("einsum_5d", npy, 0, 0, false, false, 2); } -TEST_P(Test_ONNX_layers, DISABLED_Einsum_InnerProduct) +TEST_P(Test_ONNX_layers, Einsum_InnerProduct) { testONNXModels("einsum_inner", npy, 0, 0, false, false, 2); } -TEST_P(Test_ONNX_layers, DISABLED_Einsum_HadamardProduct) +TEST_P(Test_ONNX_layers, Einsum_HadamardProduct) { testONNXModels("einsum_hadamard", npy, 0, 0, false, false, 2); }