mirror of
https://github.com/opencv/opencv.git
synced 2025-08-06 06:26:29 +08:00
Supported TF concat 3d
This commit is contained in:
parent
8badf7f354
commit
aa08900ac8
@ -46,6 +46,14 @@ static int toNCHW(int idx)
|
||||
else return (4 + idx) % 3 + 1;
|
||||
}
|
||||
|
||||
static int toNCDHW(int idx)
|
||||
{
|
||||
CV_Assert(-5 <= idx && idx < 5);
|
||||
if (idx == 0) return 0;
|
||||
else if (idx > 0) return idx % 4 + 1;
|
||||
else return (5 + idx) % 4 + 1;
|
||||
}
|
||||
|
||||
// This values are used to indicate layer output's data layout where it's possible.
|
||||
enum DataLayout
|
||||
{
|
||||
@ -1313,6 +1321,8 @@ void TFImporter::populateNet(Net dstNet)
|
||||
|
||||
if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NHWC)
|
||||
axis = toNCHW(axis);
|
||||
else if (getDataLayout(name, data_layouts) == DATA_LAYOUT_NDHWC)
|
||||
axis = toNCDHW(axis);
|
||||
layerParams.set("axis", axis);
|
||||
|
||||
// input(0) or input(n-1) is concat_dim
|
||||
|
@ -196,6 +196,7 @@ TEST_P(Test_TensorFlow_layers, pad_and_concat)
|
||||
TEST_P(Test_TensorFlow_layers, concat_axis_1)
|
||||
{
|
||||
runTensorFlowNet("concat_axis_1");
|
||||
runTensorFlowNet("concat_3d");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, batch_norm_1)
|
||||
|
Loading…
Reference in New Issue
Block a user