mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 20:42:53 +08:00
Merge pull request #23482 from zihaomu:onnx_opset13_split
DNN: support the split node of onnx opset >= 13
This commit is contained in:
commit
aa17f881b1
@ -1467,6 +1467,10 @@ void ONNXImporter::parseSlice(LayerParams& layerParams, const opencv_onnx::NodeP
|
|||||||
|
|
||||||
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||||
{
|
{
|
||||||
|
int axis = layerParams.get<int>("axis", 0);
|
||||||
|
MatShape inpShape = outShapes[node_proto.input(0)];
|
||||||
|
axis = normalize_axis(axis, inpShape.size());
|
||||||
|
|
||||||
if (layerParams.has("split"))
|
if (layerParams.has("split"))
|
||||||
{
|
{
|
||||||
DictValue splits = layerParams.get("split");
|
DictValue splits = layerParams.get("split");
|
||||||
@ -1480,13 +1484,26 @@ void ONNXImporter::parseSplit(LayerParams& layerParams, const opencv_onnx::NodeP
|
|||||||
}
|
}
|
||||||
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
|
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
|
||||||
}
|
}
|
||||||
|
else if (node_proto.input_size() == 2) // opset >= 13, the split will be stored at the second input, instead of the attribute.
|
||||||
|
{
|
||||||
|
CV_Assert(constBlobs.find(node_proto.input(1)) != constBlobs.end());
|
||||||
|
Mat splitsBlob = getBlob(node_proto, 1);
|
||||||
|
int splitSize = splitsBlob.total();
|
||||||
|
|
||||||
|
std::vector<int> slicePoints(splitSize - 1, splitsBlob.at<int>(0));
|
||||||
|
for (int i = 1; i < splitSize - 1; ++i)
|
||||||
|
{
|
||||||
|
slicePoints[i] = slicePoints[i - 1] + splitsBlob.at<int>(i);
|
||||||
|
}
|
||||||
|
layerParams.set("slice_point", DictValue::arrayInt(&slicePoints[0], slicePoints.size()));
|
||||||
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
layerParams.set("num_split", node_proto.output_size());
|
layerParams.set("num_split", node_proto.output_size());
|
||||||
}
|
}
|
||||||
int depth = layerParams.get<int>("depth", CV_32F);
|
int depth = layerParams.get<int>("depth", CV_32F);
|
||||||
layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice";
|
layerParams.type = (depth == CV_8S) ? "SliceInt8" : "Slice";
|
||||||
layerParams.set("axis", layerParams.get<float>("axis", 0));
|
layerParams.set("axis", axis);
|
||||||
addLayer(layerParams, node_proto);
|
addLayer(layerParams, node_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1149,6 +1149,8 @@ TEST_P(Test_ONNX_layers, Split)
|
|||||||
testONNXModels("split_2");
|
testONNXModels("split_2");
|
||||||
testONNXModels("split_3");
|
testONNXModels("split_3");
|
||||||
testONNXModels("split_4");
|
testONNXModels("split_4");
|
||||||
|
testONNXModels("split_5");
|
||||||
|
testONNXModels("split_6");
|
||||||
testONNXModels("split_neg_axis");
|
testONNXModels("split_neg_axis");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user