mirror of
https://github.com/opencv/opencv.git
synced 2025-06-11 03:33:28 +08:00
fix Flatten layer
This commit is contained in:
parent
f071207463
commit
fec2c7e715
@ -100,7 +100,6 @@ public:
|
||||
{
|
||||
outputShapeVec.push_back(inputs[0][i]);
|
||||
}
|
||||
CV_Assert(outputShapeVec.size() <= 4);
|
||||
|
||||
outputs.resize(inputs.size(), outputShapeVec);
|
||||
|
||||
|
@ -1781,20 +1781,67 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
|
||||
{
|
||||
opencv_onnx::NodeProto node_proto = node_proto_;
|
||||
CV_CheckEQ(node_proto.input_size(), 1, "");
|
||||
int axis_ = layerParams.get<int>("axis", 1);
|
||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||
{
|
||||
Mat input = getBlob(node_proto, 0);
|
||||
int axis = normalize_axis(layerParams.get<int>("axis", 1), input.dims);
|
||||
int axis = normalize_axis(axis_, input.dims);
|
||||
|
||||
std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
|
||||
out_size.push_back(input.total(axis));
|
||||
Mat output = input.reshape(1, out_size);
|
||||
int out_size[2] = {1, 1};
|
||||
for (int i = 0; i < axis; ++i)
|
||||
{
|
||||
out_size[0] *= input.size[i];
|
||||
}
|
||||
for (int i = axis; i < input.dims; ++i)
|
||||
{
|
||||
out_size[1] *= input.size[i];
|
||||
}
|
||||
|
||||
Mat output = input.reshape(1, 2, out_size);
|
||||
addConstant(layerParams.name, output);
|
||||
return;
|
||||
}
|
||||
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
|
||||
CV_Assert(shapeIt != outShapes.end());
|
||||
MatShape inpShape = shapeIt->second;
|
||||
int axis = normalize_axis(axis_, inpShape.size());
|
||||
|
||||
if (axis == 0 || axis == inpShape.size())
|
||||
{
|
||||
LayerParams reshapeLp;
|
||||
reshapeLp.name = layerParams.name + "/reshape";
|
||||
reshapeLp.type = "Reshape";
|
||||
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
|
||||
|
||||
inpShape.insert(axis == 0 ? inpShape.begin() : inpShape.end(), 1);
|
||||
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
|
||||
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_input(node_proto.input(0));
|
||||
proto.add_output(reshapeLp.name);
|
||||
addLayer(reshapeLp, proto);
|
||||
node_proto.set_input(0, reshapeLp.name);
|
||||
axis += 1;
|
||||
}
|
||||
|
||||
LayerParams first_pass;
|
||||
first_pass.name = layerParams.name + "/flatten";
|
||||
CV_Assert(layer_id.find(first_pass.name) == layer_id.end());
|
||||
first_pass.type = "Flatten";
|
||||
first_pass.set("axis", 0);
|
||||
first_pass.set("end_axis", axis - 1);
|
||||
|
||||
opencv_onnx::NodeProto proto;
|
||||
proto.add_input(node_proto.input(0));
|
||||
proto.add_output(first_pass.name);
|
||||
addLayer(first_pass, proto);
|
||||
|
||||
layerParams.set("axis", 1);
|
||||
node_proto.set_input(0, first_pass.name);
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
|
@ -17,12 +17,6 @@
|
||||
"test_elu",
|
||||
"test_elu_default",
|
||||
"test_exp",
|
||||
"test_flatten_axis0",
|
||||
"test_flatten_axis2",
|
||||
"test_flatten_axis3",
|
||||
"test_flatten_negative_axis1",
|
||||
"test_flatten_negative_axis2",
|
||||
"test_flatten_negative_axis4",
|
||||
"test_leakyrelu",
|
||||
"test_leakyrelu_default",
|
||||
"test_logsoftmax_axis_1",
|
||||
|
@ -561,35 +561,23 @@ CASE(test_eyelike_with_dtype)
|
||||
CASE(test_eyelike_without_dtype)
|
||||
// no filter
|
||||
CASE(test_flatten_axis0)
|
||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
||||
SKIP;
|
||||
#endif
|
||||
// no filter
|
||||
CASE(test_flatten_axis1)
|
||||
// no filter
|
||||
CASE(test_flatten_axis2)
|
||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
||||
SKIP;
|
||||
#endif
|
||||
// no filter
|
||||
CASE(test_flatten_axis3)
|
||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
||||
SKIP;
|
||||
#endif
|
||||
// no filter
|
||||
CASE(test_flatten_default_axis)
|
||||
// no filter
|
||||
CASE(test_flatten_negative_axis1)
|
||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
||||
SKIP;
|
||||
#endif
|
||||
// no filter
|
||||
CASE(test_flatten_negative_axis2)
|
||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
||||
SKIP;
|
||||
#endif
|
||||
// no filter
|
||||
CASE(test_flatten_negative_axis3)
|
||||
// no filter
|
||||
CASE(test_flatten_negative_axis4)
|
||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
||||
SKIP;
|
||||
#endif
|
||||
// no filter
|
||||
CASE(test_floor)
|
||||
// no filter
|
||||
CASE(test_floor_example)
|
||||
|
@ -7,12 +7,6 @@
|
||||
"test_castlike_FLOAT_to_STRING_expanded",
|
||||
"test_castlike_STRING_to_FLOAT_expanded",
|
||||
"test_concat_1d_axis_negative_1",
|
||||
"test_flatten_axis0",
|
||||
"test_flatten_axis2",
|
||||
"test_flatten_axis3",
|
||||
"test_flatten_negative_axis1",
|
||||
"test_flatten_negative_axis2",
|
||||
"test_flatten_negative_axis4",
|
||||
"test_logsoftmax_default_axis",
|
||||
"test_maxpool_2d_dilations",
|
||||
"test_maxpool_2d_same_lower",
|
||||
|
Loading…
Reference in New Issue
Block a user