mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 03:30:34 +08:00
support the split node of onnx opset >= 13
This commit is contained in:
parent
d2dbaa4cd1
commit
51281f8d69
@ -1467,6 +1467,10 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
|
||||
|
||||
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
int axis = layerParams.get<int>("axis", 0);
|
||||
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||
axis = normalize_axis(axis, inpShape.size());
|
||||
|
||||
if (layerParams.has("split"))
|
||||
{
|
||||
DictValue splits = layerParams.get("split");
|
||||
@ -1480,13 +1484,26 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP
|
||||
}
|
||||
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
|
||||
}
|
||||
else if (node_proto.input_size() == 2) // opset >= 13, the split will be stored at the second input, instead of the attribute.
|
||||
{
|
||||
CV_Assert(constBlobs.find(node_proto.input(1)) != constBlobs.end());
|
||||
Mat splitsBlob = getBlob(node_proto, 1);
|
||||
int splitSize = splitsBlob.total();
|
||||
|
||||
std::vector<int> slicePoints(splitSize - 1, splitsBlob.at<int>(0));
|
||||
for (int i = 1; i < splitSize - 1; ++i)
|
||||
{
|
||||
slicePoints[i] = slicePoints[i - 1] + splitsBlob.at<int>(i);
|
||||
}
|
||||
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
|
||||
}
|
||||
else
|
||||
{
|
||||
layerParams.set("num_split", node_proto.output_size());
|
||||
}
|
||||
int depth = layerParams.get<int>("depth", CV_32F);
|
||||
layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice";
|
||||
layerParams.set("axis", layerParams.get<float>("axis", 0));
|
||||
layerParams.set("axis", axis);
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
|
@ -1149,6 +1149,8 @@ TEST_P(Test_ONNX_layers, Split)
|
||||
testONNXModels("split_2");
|
||||
testONNXModels("split_3");
|
||||
testONNXModels("split_4");
|
||||
testONNXModels("split_5");
|
||||
testONNXModels("split_6");
|
||||
testONNXModels("split_neg_axis");
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user