mirror of
https://github.com/opencv/opencv.git
synced 2024-11-24 03:00:14 +08:00
Merge pull request #18078 from l-bat:fix_matmul
This commit is contained in:
commit
3b5813c035
@ -641,6 +641,17 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
{
|
||||
layerParams.type = "Scale";
|
||||
layerParams.set("bias_term", true);
|
||||
int axis = 1;
|
||||
for (int i = 0; i < graph_proto.initializer_size(); i++)
|
||||
{
|
||||
opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i);
|
||||
if (tensor_proto.name() == node_proto.input(const_blob_id))
|
||||
{
|
||||
axis = inpShape.size() - tensor_proto.dims_size();
|
||||
break;
|
||||
}
|
||||
}
|
||||
layerParams.set("axis", axis);
|
||||
blob = blob.reshape(1, 1);
|
||||
layerParams.blobs.push_back((isSub ? -1 : 1) * blob);
|
||||
}
|
||||
@ -911,13 +922,20 @@ void ONNXImporter::populateNet(Net dstNet)
|
||||
CV_Assert(node_proto.input_size() == 2);
|
||||
layerParams.type = "InnerProduct";
|
||||
layerParams.set("bias_term", false);
|
||||
CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end());
|
||||
int firstInpDims = outShapes[node_proto.input(0)].size();
|
||||
int secondInpDims;
|
||||
|
||||
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
|
||||
{
|
||||
Mat blob = getBlob(node_proto, constBlobs, 1);
|
||||
secondInpDims = blob.dims;
|
||||
layerParams.blobs.push_back(blob.t());
|
||||
layerParams.set("num_output", layerParams.blobs[0].size[0]);
|
||||
} else {
|
||||
secondInpDims = outShapes[node_proto.input(1)].size();
|
||||
}
|
||||
layerParams.set("axis", firstInpDims - secondInpDims + 1);
|
||||
}
|
||||
else if (layer_type == "Mul" || layer_type == "Div")
|
||||
{
|
||||
|
@ -404,6 +404,15 @@ TEST_P(Test_ONNX_layers, MatMul)
|
||||
testONNXModels("matmul_4d");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, MatMulAdd)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
|
||||
testONNXModels("matmul_add");
|
||||
}
|
||||
|
||||
TEST_P(Test_ONNX_layers, Expand)
|
||||
{
|
||||
testONNXModels("expand_batch");
|
||||
|
Loading…
Reference in New Issue
Block a user