diff --git a/modules/dnn/src/layers/fully_connected_layer.cpp b/modules/dnn/src/layers/fully_connected_layer.cpp index 71ca706ac4..1be5bbe366 100644 --- a/modules/dnn/src/layers/fully_connected_layer.cpp +++ b/modules/dnn/src/layers/fully_connected_layer.cpp @@ -80,6 +80,9 @@ public: FullyConnectedLayerImpl(const LayerParams& params) { setParamsFrom(params); + transA = params.get("transA", false); + transB = params.get("transB", false); + bias = params.get("bias_term", true); axis = params.get("axis", 1); if (!blobs.empty()) @@ -116,30 +119,48 @@ public: std::vector &) const CV_OVERRIDE { int numOutput, cAxis; + + std::vector inputsTmp; + inputsTmp.assign(inputs.begin(), inputs.end()); + if (blobs.empty()) { - CV_CheckEQ(inputs.size(), (size_t)2, ""); - numOutput = inputs[1].back(); - cAxis = inputs[0].size() - 1; - int dims = inputs[0].size(); - CV_CheckEQ(inputs[1].size(), (size_t)dims, ""); + CV_CheckEQ(inputsTmp.size(), (size_t)2, ""); + + if (transA) + { + CV_CheckEQ(inputsTmp[0].size(), (size_t)2, ""); + std::swap(inputsTmp[0][0], inputsTmp[0][1]); + } + + if (transB) + { + CV_CheckEQ(inputsTmp[1].size(), (size_t)2, ""); + std::swap(inputsTmp[1][0], inputsTmp[1][1]); + } + + numOutput = inputsTmp[1].back(); + cAxis = inputsTmp[0].size() - 1; + int dims = inputsTmp[0].size(); + CV_CheckEQ(inputsTmp[1].size(), (size_t)dims, ""); CV_CheckGE(dims, 2, ""); for (int i = 0; i < dims - 2; i++) - CV_CheckEQ(inputs[0][i], inputs[1][i], ""); - CV_CheckEQ(inputs[0].back(), inputs[1][dims - 2], ""); + CV_CheckEQ(inputsTmp[0][i], inputsTmp[1][i], ""); + CV_CheckEQ(inputsTmp[0].back(), inputsTmp[1][dims - 2], ""); } else { - CV_CheckEQ(inputs.size(), (size_t)1, ""); + CV_Assert(!transA && !transB); + CV_CheckEQ(inputsTmp.size(), (size_t)1, ""); CV_CheckEQ(blobs[0].dims, 2, ""); numOutput = blobs[0].size[0]; CV_Assert(!bias || (size_t)numOutput == blobs[1].total()); - cAxis = normalize_axis(axis, inputs[0]); + cAxis = normalize_axis(axis, inputsTmp[0]); } MatShape 
outShape(cAxis + 1); for (int i = 0; i < cAxis; ++i) - outShape[i] = inputs[0][i]; + outShape[i] = inputsTmp[0][i]; outShape.back() = numOutput; outputs.resize(1, outShape); @@ -148,15 +169,15 @@ public: virtual bool supportBackend(int backendId) CV_OVERRIDE { + bool tranAorB = transA || transB; #ifdef HAVE_INF_ENGINE if (backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH) - return axis == 1; + return axis == 1 && !tranAorB; #endif - return backendId == DNN_BACKEND_OPENCV || - backendId == DNN_BACKEND_CUDA || - (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1) || - (backendId == DNN_BACKEND_WEBNN && axis == 1); + (backendId == DNN_BACKEND_CUDA && !tranAorB) || + (backendId == DNN_BACKEND_HALIDE && haveHalide() && axis == 1 && !tranAorB) || + (backendId == DNN_BACKEND_WEBNN && axis == 1 && !tranAorB); } virtual bool setActivation(const Ptr& layer) CV_OVERRIDE @@ -497,6 +518,7 @@ public: if (!blobs.empty()) { + CV_Assert(!transA && !transB); int axisCan = normalize_axis(axis, input[0].dims); int outerSize = input[0].total(0, axisCan); @@ -511,15 +533,30 @@ public: } else { - float* inpData = input[0].ptr(); - float* weightData = input[1].ptr(); + Mat input0 = input[0]; + Mat input1 = input[1]; + + if (transA) + { + CV_Assert(input0.dims == 2); + input0 = input0.t(); + } + + if (transB) + { + CV_Assert(input1.dims == 2); + input1 = input1.t(); + } + + float* inpData = input0.ptr(); + float* weightData = input1.ptr(); float* outData = output[0].ptr(); int dims = output[0].dims; int numSlice = output[0].total() / output[0].total(dims - 2); - int m = input[0].size[dims - 2]; - int n = input[0].size[dims - 1]; - int k = input[1].size[dims - 1]; + int m = input0.size[dims - 2]; + int n = input0.size[dims - 1]; + int k = input1.size[dims - 1]; for (int i = 0; i < numSlice; i++) { Mat inpSlice(m, n, CV_32F, inpData); @@ -716,6 +753,7 @@ public: bool bias; Mat weightsMat, biasMat; + bool transA, transB; Ptr activ; }; diff --git 
a/modules/dnn/src/onnx/onnx_importer.cpp b/modules/dnn/src/onnx/onnx_importer.cpp index 6ef8894063..c626a993fb 100644 --- a/modules/dnn/src/onnx/onnx_importer.cpp +++ b/modules/dnn/src/onnx/onnx_importer.cpp @@ -2000,18 +2000,9 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr { CV_Assert(node_proto.input_size() >= 2); layerParams.type = "InnerProduct"; - Mat weights = getBlob(node_proto, 1); + int transA = layerParams.get("transA", 0); + layerParams.set("transA", transA == 1); - if (!layerParams.get("transB", 0)) - { - transpose(weights, weights); - } - layerParams.blobs.push_back(weights); - - if (node_proto.input_size() == 3) { - Mat bias = getBlob(node_proto, 2); - layerParams.blobs.push_back(bias); - } if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) { Mat inputBuf = getBlob(node_proto, 0); @@ -2026,7 +2017,43 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr addLayer(constParams, proto); } - layerParams.set("num_output", layerParams.blobs[0].size[0]); + int transB = layerParams.get("transB", 0); + if (constBlobs.find(node_proto.input(1)) != constBlobs.end()) + { + Mat weights = getBlob(node_proto, 1); + + if (transA == 0) // optimized branch, for now, we can only optimize the Gemm when transA = 0. + { + if (transB == 0) + { + transpose(weights, weights); + } + layerParams.set("transB", false); + layerParams.blobs.push_back(weights); + layerParams.set("num_output", layerParams.blobs[0].size[0]); + } + else // no optimized branch, TODO! optimize when the transA==1. 
+ { + LayerParams constParams; + constParams.name = node_proto.input(1); + constParams.type = "Const"; + constParams.blobs.push_back(weights); + + opencv_onnx::NodeProto proto; + proto.add_output(constParams.name); + addLayer(constParams, proto); + layerParams.set("transB", transB == 1); + } + } + else + layerParams.set("transB", transB == 1); + + if (node_proto.input_size() == 3) + { + Mat bias = getBlob(node_proto, 2); + layerParams.blobs.push_back(bias); + } + layerParams.set("bias_term", node_proto.input_size() == 3); addLayer(layerParams, node_proto); } diff --git a/modules/dnn/test/test_onnx_importer.cpp b/modules/dnn/test/test_onnx_importer.cpp index 43dc817733..cee7cf023f 100644 --- a/modules/dnn/test/test_onnx_importer.cpp +++ b/modules/dnn/test/test_onnx_importer.cpp @@ -1829,6 +1829,7 @@ TEST_P(Test_ONNX_layers, Gemm) { testONNXModels("gemm_no_transB"); testONNXModels("gemm_transB_0"); + testONNXModels("gemm_first_const"); } TEST_P(Test_ONNX_layers, Quantized_Convolution)