support the split node of onnx opset >= 13

This commit is contained in:
zihaomu 2023-04-11 16:18:50 +08:00
parent d2dbaa4cd1
commit 51281f8d69
2 changed files with 20 additions and 1 deletions

View File

@ -1467,6 +1467,10 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
int axis = layerParams.get<int>("axis", 0);
MatShape inpShape = outShapes[node_proto.input(0)];
axis = normalize_axis(axis, inpShape.size());
if (layerParams.has("split"))
{
DictValue splits = layerParams.get("split");
@ -1480,13 +1484,26 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP
}
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
}
else if (node_proto.input_size() == 2) // opset >= 13, the split will be stored at the second input, instead of the attribute.
{
CV_Assert(constBlobs.find(node_proto.input(1)) != constBlobs.end());
Mat splitsBlob = getBlob(node_proto, 1);
int splitSize = splitsBlob.total();
std::vector<int> slicePoints(splitSize - 1, splitsBlob.at<int>(0));
for (int i = 1; i < splitSize - 1; ++i)
{
slicePoints[i] = slicePoints[i - 1] + splitsBlob.at<int>(i);
}
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
}
else
{
layerParams.set("num_split", node_proto.output_size());
}
int depth = layerParams.get<int>("depth", CV_32F);
layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice";
layerParams.set("axis", layerParams.get<float>("axis", 0));
layerParams.set("axis", axis);
addLayer(layerParams, node_proto);
}

View File

@ -1149,6 +1149,8 @@ TEST_P(Test_ONNX_layers, Split)
testONNXModels("split_2");
testONNXModels("split_3");
testONNXModels("split_4");
testONNXModels("split_5");
testONNXModels("split_6");
testONNXModels("split_neg_axis");
}