mirror of
https://github.com/opencv/opencv.git
synced 2025-01-19 23:19:23 +08:00
Merge pull request #22148 from zihaomu:gemm_onnx_bug_fix_branch34
This commit is contained in:
commit
6234f01a6d
@ -1759,15 +1759,15 @@ void ONNXImporter::parseBatchNormalization(LayerParams& layerParams, const openc
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
// A * B + C = Y, we require that the dimension of A is [m, k], and the dimension of B is [n, k].
|
||||
// And the dim of output Y is [m, n]
|
||||
void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
CV_Assert(node_proto.input_size() >= 2);
|
||||
layerParams.type = "InnerProduct";
|
||||
Mat weights = getBlob(node_proto, 1);
|
||||
int ind_num_out = 0;
|
||||
if (layerParams.has("transB") && !layerParams.get<int>("transB")) {
|
||||
if (!layerParams.get<int>("transB", 0)) {
|
||||
transpose(weights, weights);
|
||||
ind_num_out = 1;
|
||||
}
|
||||
layerParams.blobs.push_back(weights);
|
||||
|
||||
@ -1789,7 +1789,7 @@ void ONNXImporter::parseGemm(LayerParams& layerParams, const opencv_onnx::NodePr
|
||||
addLayer(constParams, proto);
|
||||
}
|
||||
|
||||
layerParams.set("num_output", layerParams.blobs[0].size[ind_num_out]);
|
||||
layerParams.set("num_output", layerParams.blobs[0].size[0]);
|
||||
layerParams.set("bias_term", node_proto.input_size() == 3);
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
@ -1389,6 +1389,12 @@ TEST_P(Test_ONNX_layers, DivConst)
|
||||
testONNXModels("div_const");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Gemm)
|
||||
{
|
||||
testONNXModels("gemm_no_transB");
|
||||
testONNXModels("gemm_transB_0");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, OutputRegistration)
|
||||
{
|
||||
testONNXModels("output_registration", npy, 0, 0, false, true, 2);
|
||||
|
Loading…
Reference in New Issue
Block a user