Merge pull request #23419 from dkurt:onnx_fixes

Several fixes for ONNX importer: Expand, Gather
This commit is contained in:
Alexander Smorkalov 2023-04-02 11:40:56 +03:00 committed by GitHub
commit d8c80ff5a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 4 deletions

View File

@ -2435,12 +2435,18 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
}
else
{
inpShape = shape(getBlob(input0));
Mat blob = getBlob(input0);
if (constBlobsExtraInfo.find(node_proto.input(0)) != constBlobsExtraInfo.end() &&
getBlobExtraInfo(node_proto, 0).real_ndims == 1) {
inpShape = {(int)blob.total()};
} else {
inpShape = shape(blob);
}
}
String srcName = input0;
// Unsqueeze and repeat along new axis
if (targetShape.size() == inpShape.size() + 1)
if (targetShape.size() > inpShape.size())
{
inpShape.insert(inpShape.begin(), targetShape.size() - inpShape.size(), 1);
for (int i = 0; i < targetShape.size(); i++)
@ -2486,7 +2492,7 @@ void ONNXImporter::parseExpand(LayerParams& layerParams, const opencv_onnx::Node
{
if (broadcast_axes.empty())
{
addConstant(output_name, getBlob(node_proto, 0));
addConstant(output_name, getBlob(node_proto, 0).reshape(1, targetShape));
return;
}
@ -2719,7 +2725,8 @@ void ONNXImporter::parseGather(LayerParams& layerParams, const opencv_onnx::Node
runLayer(layerParams, inputs, output);
output.back().convertTo(output.back(), type);
output.back().dims = std::max(input_real_ndims - real_ndims, 1);
if (real_ndims < 2) // In case of scalars or 1D vectors, OpenCV initializes 2D cv::Mat
output.back().dims = std::max(input_real_ndims - real_ndims, 1);
addConstant(node_proto.output(0), output.back());
return;
}

View File

@ -2487,6 +2487,11 @@ TEST_P(Test_ONNX_layers, Gelu)
testONNXModels("gelu_approximation");
}
TEST_P(Test_ONNX_layers, OpenAI_CLIP_head)
{
testONNXModels("clip-vit-base-head");
}
INSTANTIATE_TEST_CASE_P(/**/, Test_ONNX_nets, dnnBackendsAndTargets());
}} // namespace