fix Flatten layer

This commit is contained in:
Smirnov Egor 2021-12-16 22:41:47 +03:00
parent f071207463
commit fec2c7e715
5 changed files with 58 additions and 36 deletions

View File

@ -100,7 +100,6 @@ public:
{ {
outputShapeVec.push_back(inputs[0][i]); outputShapeVec.push_back(inputs[0][i]);
} }
CV_Assert(outputShapeVec.size() <= 4);
outputs.resize(inputs.size(), outputShapeVec); outputs.resize(inputs.size(), outputShapeVec);

View File

@ -1781,20 +1781,67 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto) void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
{ {
opencv_onnx::NodeProto node_proto = node_proto_;
CV_CheckEQ(node_proto.input_size(), 1, ""); CV_CheckEQ(node_proto.input_size(), 1, "");
int axis_ = layerParams.get<int>("axis", 1);
if (constBlobs.find(node_proto.input(0)) != constBlobs.end()) if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
{ {
Mat input = getBlob(node_proto, 0); Mat input = getBlob(node_proto, 0);
int axis = normalize_axis(layerParams.get<int>("axis", 1), input.dims); int axis = normalize_axis(axis_, input.dims);
std::vector<int> out_size(&input.size[0], &input.size[0] + axis); int out_size[2] = {1, 1};
out_size.push_back(input.total(axis)); for (int i = 0; i < axis; ++i)
Mat output = input.reshape(1, out_size); {
out_size[0] *= input.size[i];
}
for (int i = axis; i < input.dims; ++i)
{
out_size[1] *= input.size[i];
}
Mat output = input.reshape(1, 2, out_size);
addConstant(layerParams.name, output); addConstant(layerParams.name, output);
return; return;
} }
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
CV_Assert(shapeIt != outShapes.end());
MatShape inpShape = shapeIt->second;
int axis = normalize_axis(axis_, inpShape.size());
if (axis == 0 || axis == inpShape.size())
{
LayerParams reshapeLp;
reshapeLp.name = layerParams.name + "/reshape";
reshapeLp.type = "Reshape";
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
inpShape.insert(axis == 0 ? inpShape.begin() : inpShape.end(), 1);
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(reshapeLp.name);
addLayer(reshapeLp, proto);
node_proto.set_input(0, reshapeLp.name);
axis += 1;
}
LayerParams first_pass;
first_pass.name = layerParams.name + "/flatten";
CV_Assert(layer_id.find(first_pass.name) == layer_id.end());
first_pass.type = "Flatten";
first_pass.set("axis", 0);
first_pass.set("end_axis", axis - 1);
opencv_onnx::NodeProto proto;
proto.add_input(node_proto.input(0));
proto.add_output(first_pass.name);
addLayer(first_pass, proto);
layerParams.set("axis", 1);
node_proto.set_input(0, first_pass.name);
addLayer(layerParams, node_proto); addLayer(layerParams, node_proto);
} }

View File

@ -17,12 +17,6 @@
"test_elu", "test_elu",
"test_elu_default", "test_elu_default",
"test_exp", "test_exp",
"test_flatten_axis0",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_negative_axis1",
"test_flatten_negative_axis2",
"test_flatten_negative_axis4",
"test_leakyrelu", "test_leakyrelu",
"test_leakyrelu_default", "test_leakyrelu_default",
"test_logsoftmax_axis_1", "test_logsoftmax_axis_1",

View File

@ -561,35 +561,23 @@ CASE(test_eyelike_with_dtype)
CASE(test_eyelike_without_dtype) CASE(test_eyelike_without_dtype)
// no filter // no filter
CASE(test_flatten_axis0) CASE(test_flatten_axis0)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000) // no filter
SKIP;
#endif
CASE(test_flatten_axis1) CASE(test_flatten_axis1)
// no filter // no filter
CASE(test_flatten_axis2) CASE(test_flatten_axis2)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000) // no filter
SKIP;
#endif
CASE(test_flatten_axis3) CASE(test_flatten_axis3)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000) // no filter
SKIP;
#endif
CASE(test_flatten_default_axis) CASE(test_flatten_default_axis)
// no filter // no filter
CASE(test_flatten_negative_axis1) CASE(test_flatten_negative_axis1)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000) // no filter
SKIP;
#endif
CASE(test_flatten_negative_axis2) CASE(test_flatten_negative_axis2)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000) // no filter
SKIP;
#endif
CASE(test_flatten_negative_axis3) CASE(test_flatten_negative_axis3)
// no filter // no filter
CASE(test_flatten_negative_axis4) CASE(test_flatten_negative_axis4)
#if INF_ENGINE_VER_MAJOR_EQ(2021040000) // no filter
SKIP;
#endif
CASE(test_floor) CASE(test_floor)
// no filter // no filter
CASE(test_floor_example) CASE(test_floor_example)

View File

@ -7,12 +7,6 @@
"test_castlike_FLOAT_to_STRING_expanded", "test_castlike_FLOAT_to_STRING_expanded",
"test_castlike_STRING_to_FLOAT_expanded", "test_castlike_STRING_to_FLOAT_expanded",
"test_concat_1d_axis_negative_1", "test_concat_1d_axis_negative_1",
"test_flatten_axis0",
"test_flatten_axis2",
"test_flatten_axis3",
"test_flatten_negative_axis1",
"test_flatten_negative_axis2",
"test_flatten_negative_axis4",
"test_logsoftmax_default_axis", "test_logsoftmax_default_axis",
"test_maxpool_2d_dilations", "test_maxpool_2d_dilations",
"test_maxpool_2d_same_lower", "test_maxpool_2d_same_lower",