mirror of
https://github.com/opencv/opencv.git
synced 2025-06-07 17:44:04 +08:00
Merge pull request #9762 from dkurt:fix_tensorflow_split_layer
This commit is contained in:
commit
5f6ce6f4b0
@ -116,7 +116,7 @@ public:
|
|||||||
}
|
}
|
||||||
else // Divide input blob on equal parts by axis.
|
else // Divide input blob on equal parts by axis.
|
||||||
{
|
{
|
||||||
CV_Assert(0 < axis && axis < inpShape.size());
|
CV_Assert(0 <= axis && axis < inpShape.size());
|
||||||
CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
|
CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
|
||||||
inpShape[axis] /= requiredOutputs;
|
inpShape[axis] /= requiredOutputs;
|
||||||
outputs.resize(requiredOutputs, inpShape);
|
outputs.resize(requiredOutputs, inpShape);
|
||||||
|
@ -866,8 +866,6 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
CV_Assert(layer.input_size() == 2);
|
CV_Assert(layer.input_size() == 2);
|
||||||
// num_split
|
// num_split
|
||||||
// 1st blob is dims tensor
|
// 1st blob is dims tensor
|
||||||
layerParams.set("slice_point", DictValue::arrayReal((double*)0, 0));
|
|
||||||
|
|
||||||
int axis = getConstBlob(layer, value_id, 0).int_val().Get(0);
|
int axis = getConstBlob(layer, value_id, 0).int_val().Get(0);
|
||||||
layerParams.set("axis", toNCHW[axis]);
|
layerParams.set("axis", toNCHW[axis]);
|
||||||
|
|
||||||
|
@ -170,4 +170,9 @@ TEST(Test_TensorFlow, lstm)
|
|||||||
runTensorFlowNet("lstm");
|
runTensorFlowNet("lstm");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(Test_TensorFlow, split)
|
||||||
|
{
|
||||||
|
runTensorFlowNet("split_equals");
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user