Merge pull request #9762 from dkurt:fix_tensorflow_split_layer

This commit is contained in:
Vadim Pisarevsky 2017-10-05 10:51:49 +00:00
commit 5f6ce6f4b0
3 changed files with 6 additions and 3 deletions

View File

@ -116,7 +116,7 @@ public:
} }
else // Divide input blob on equal parts by axis. else // Divide input blob on equal parts by axis.
{ {
CV_Assert(0 < axis && axis < inpShape.size()); CV_Assert(0 <= axis && axis < inpShape.size());
CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0); CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
inpShape[axis] /= requiredOutputs; inpShape[axis] /= requiredOutputs;
outputs.resize(requiredOutputs, inpShape); outputs.resize(requiredOutputs, inpShape);

View File

@ -866,8 +866,6 @@ void TFImporter::populateNet(Net dstNet)
CV_Assert(layer.input_size() == 2); CV_Assert(layer.input_size() == 2);
// num_split // num_split
// 1st blob is dims tensor // 1st blob is dims tensor
layerParams.set("slice_point", DictValue::arrayReal((double*)0, 0));
int axis = getConstBlob(layer, value_id, 0).int_val().Get(0); int axis = getConstBlob(layer, value_id, 0).int_val().Get(0);
layerParams.set("axis", toNCHW[axis]); layerParams.set("axis", toNCHW[axis]);

View File

@ -170,4 +170,9 @@ TEST(Test_TensorFlow, lstm)
runTensorFlowNet("lstm"); runTensorFlowNet("lstm");
} }
TEST(Test_TensorFlow, split)
{
runTensorFlowNet("split_equals");
}
} }