mirror of
https://github.com/opencv/opencv.git
synced 2025-06-08 01:53:19 +08:00
Fix TF Split layer
This commit is contained in:
parent
e4e0bb533d
commit
0d2bc7b5fd
@ -366,6 +366,7 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
|
|||||||
*/
|
*/
|
||||||
std::vector<std::vector<Range> > sliceRanges;
|
std::vector<std::vector<Range> > sliceRanges;
|
||||||
int axis;
|
int axis;
|
||||||
|
int num_split;
|
||||||
|
|
||||||
static Ptr<SliceLayer> create(const LayerParams ¶ms);
|
static Ptr<SliceLayer> create(const LayerParams ¶ms);
|
||||||
};
|
};
|
||||||
|
@ -61,6 +61,7 @@ public:
|
|||||||
{
|
{
|
||||||
setParamsFrom(params);
|
setParamsFrom(params);
|
||||||
axis = params.get<int>("axis", 1);
|
axis = params.get<int>("axis", 1);
|
||||||
|
num_split = params.get<int>("num_split", 0);
|
||||||
if (params.has("slice_point"))
|
if (params.has("slice_point"))
|
||||||
{
|
{
|
||||||
CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end"));
|
CV_Assert(!params.has("begin") && !params.has("size") && !params.has("end"));
|
||||||
@ -141,9 +142,10 @@ public:
|
|||||||
else // Divide input blob on equal parts by axis.
|
else // Divide input blob on equal parts by axis.
|
||||||
{
|
{
|
||||||
CV_Assert(0 <= axis && axis < inpShape.size());
|
CV_Assert(0 <= axis && axis < inpShape.size());
|
||||||
CV_Assert(requiredOutputs > 0 && inpShape[axis] % requiredOutputs == 0);
|
int splits = num_split ? num_split : requiredOutputs;
|
||||||
inpShape[axis] /= requiredOutputs;
|
CV_Assert(splits > 0 && inpShape[axis] % splits == 0);
|
||||||
outputs.resize(requiredOutputs, inpShape);
|
inpShape[axis] /= splits;
|
||||||
|
outputs.resize(splits, inpShape);
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -1410,6 +1410,9 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
axis = toNCHW(axis);
|
axis = toNCHW(axis);
|
||||||
layerParams.set("axis", axis);
|
layerParams.set("axis", axis);
|
||||||
|
|
||||||
|
if (hasLayerAttr(layer, "num_split"))
|
||||||
|
layerParams.set("num_split", getLayerAttr(layer, "num_split").i());
|
||||||
|
|
||||||
int id = dstNet.addLayer(name, "Slice", layerParams);
|
int id = dstNet.addLayer(name, "Slice", layerParams);
|
||||||
layer_id[name] = id;
|
layer_id[name] = id;
|
||||||
|
|
||||||
|
@ -350,6 +350,11 @@ TEST_P(Test_TensorFlow_layers, l2_normalize_3d)
|
|||||||
runTensorFlowNet("l2_normalize_3d");
|
runTensorFlowNet("l2_normalize_3d");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(Test_TensorFlow_layers, Split)
|
||||||
|
{
|
||||||
|
runTensorFlowNet("split");
|
||||||
|
}
|
||||||
|
|
||||||
class Test_TensorFlow_nets : public DNNTestLayer {};
|
class Test_TensorFlow_nets : public DNNTestLayer {};
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_nets, MobileNet_SSD)
|
TEST_P(Test_TensorFlow_nets, MobileNet_SSD)
|
||||||
|
Loading…
Reference in New Issue
Block a user