mirror of
https://github.com/opencv/opencv.git
synced 2025-06-12 12:22:51 +08:00
fix Flatten layer
This commit is contained in:
parent
f071207463
commit
fec2c7e715
@ -100,7 +100,6 @@ public:
|
|||||||
{
|
{
|
||||||
outputShapeVec.push_back(inputs[0][i]);
|
outputShapeVec.push_back(inputs[0][i]);
|
||||||
}
|
}
|
||||||
CV_Assert(outputShapeVec.size() <= 4);
|
|
||||||
|
|
||||||
outputs.resize(inputs.size(), outputShapeVec);
|
outputs.resize(inputs.size(), outputShapeVec);
|
||||||
|
|
||||||
|
@ -1781,20 +1781,67 @@ void ONNXImporter::parseSqueeze(LayerParams& layerParams, const opencv_onnx::Nod
|
|||||||
addLayer(layerParams, node_proto);
|
addLayer(layerParams, node_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
void ONNXImporter::parseFlatten(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto_)
|
||||||
{
|
{
|
||||||
|
opencv_onnx::NodeProto node_proto = node_proto_;
|
||||||
CV_CheckEQ(node_proto.input_size(), 1, "");
|
CV_CheckEQ(node_proto.input_size(), 1, "");
|
||||||
|
int axis_ = layerParams.get<int>("axis", 1);
|
||||||
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
if (constBlobs.find(node_proto.input(0)) != constBlobs.end())
|
||||||
{
|
{
|
||||||
Mat input = getBlob(node_proto, 0);
|
Mat input = getBlob(node_proto, 0);
|
||||||
int axis = normalize_axis(layerParams.get<int>("axis", 1), input.dims);
|
int axis = normalize_axis(axis_, input.dims);
|
||||||
|
|
||||||
std::vector<int> out_size(&input.size[0], &input.size[0] + axis);
|
int out_size[2] = {1, 1};
|
||||||
out_size.push_back(input.total(axis));
|
for (int i = 0; i < axis; ++i)
|
||||||
Mat output = input.reshape(1, out_size);
|
{
|
||||||
|
out_size[0] *= input.size[i];
|
||||||
|
}
|
||||||
|
for (int i = axis; i < input.dims; ++i)
|
||||||
|
{
|
||||||
|
out_size[1] *= input.size[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
Mat output = input.reshape(1, 2, out_size);
|
||||||
addConstant(layerParams.name, output);
|
addConstant(layerParams.name, output);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
IterShape_t shapeIt = outShapes.find(node_proto.input(0));
|
||||||
|
CV_Assert(shapeIt != outShapes.end());
|
||||||
|
MatShape inpShape = shapeIt->second;
|
||||||
|
int axis = normalize_axis(axis_, inpShape.size());
|
||||||
|
|
||||||
|
if (axis == 0 || axis == inpShape.size())
|
||||||
|
{
|
||||||
|
LayerParams reshapeLp;
|
||||||
|
reshapeLp.name = layerParams.name + "/reshape";
|
||||||
|
reshapeLp.type = "Reshape";
|
||||||
|
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
|
||||||
|
|
||||||
|
inpShape.insert(axis == 0 ? inpShape.begin() : inpShape.end(), 1);
|
||||||
|
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
|
||||||
|
|
||||||
|
opencv_onnx::NodeProto proto;
|
||||||
|
proto.add_input(node_proto.input(0));
|
||||||
|
proto.add_output(reshapeLp.name);
|
||||||
|
addLayer(reshapeLp, proto);
|
||||||
|
node_proto.set_input(0, reshapeLp.name);
|
||||||
|
axis += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
LayerParams first_pass;
|
||||||
|
first_pass.name = layerParams.name + "/flatten";
|
||||||
|
CV_Assert(layer_id.find(first_pass.name) == layer_id.end());
|
||||||
|
first_pass.type = "Flatten";
|
||||||
|
first_pass.set("axis", 0);
|
||||||
|
first_pass.set("end_axis", axis - 1);
|
||||||
|
|
||||||
|
opencv_onnx::NodeProto proto;
|
||||||
|
proto.add_input(node_proto.input(0));
|
||||||
|
proto.add_output(first_pass.name);
|
||||||
|
addLayer(first_pass, proto);
|
||||||
|
|
||||||
|
layerParams.set("axis", 1);
|
||||||
|
node_proto.set_input(0, first_pass.name);
|
||||||
addLayer(layerParams, node_proto);
|
addLayer(layerParams, node_proto);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -17,12 +17,6 @@
|
|||||||
"test_elu",
|
"test_elu",
|
||||||
"test_elu_default",
|
"test_elu_default",
|
||||||
"test_exp",
|
"test_exp",
|
||||||
"test_flatten_axis0",
|
|
||||||
"test_flatten_axis2",
|
|
||||||
"test_flatten_axis3",
|
|
||||||
"test_flatten_negative_axis1",
|
|
||||||
"test_flatten_negative_axis2",
|
|
||||||
"test_flatten_negative_axis4",
|
|
||||||
"test_leakyrelu",
|
"test_leakyrelu",
|
||||||
"test_leakyrelu_default",
|
"test_leakyrelu_default",
|
||||||
"test_logsoftmax_axis_1",
|
"test_logsoftmax_axis_1",
|
||||||
|
@ -561,35 +561,23 @@ CASE(test_eyelike_with_dtype)
|
|||||||
CASE(test_eyelike_without_dtype)
|
CASE(test_eyelike_without_dtype)
|
||||||
// no filter
|
// no filter
|
||||||
CASE(test_flatten_axis0)
|
CASE(test_flatten_axis0)
|
||||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
// no filter
|
||||||
SKIP;
|
|
||||||
#endif
|
|
||||||
CASE(test_flatten_axis1)
|
CASE(test_flatten_axis1)
|
||||||
// no filter
|
// no filter
|
||||||
CASE(test_flatten_axis2)
|
CASE(test_flatten_axis2)
|
||||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
// no filter
|
||||||
SKIP;
|
|
||||||
#endif
|
|
||||||
CASE(test_flatten_axis3)
|
CASE(test_flatten_axis3)
|
||||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
// no filter
|
||||||
SKIP;
|
|
||||||
#endif
|
|
||||||
CASE(test_flatten_default_axis)
|
CASE(test_flatten_default_axis)
|
||||||
// no filter
|
// no filter
|
||||||
CASE(test_flatten_negative_axis1)
|
CASE(test_flatten_negative_axis1)
|
||||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
// no filter
|
||||||
SKIP;
|
|
||||||
#endif
|
|
||||||
CASE(test_flatten_negative_axis2)
|
CASE(test_flatten_negative_axis2)
|
||||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
// no filter
|
||||||
SKIP;
|
|
||||||
#endif
|
|
||||||
CASE(test_flatten_negative_axis3)
|
CASE(test_flatten_negative_axis3)
|
||||||
// no filter
|
// no filter
|
||||||
CASE(test_flatten_negative_axis4)
|
CASE(test_flatten_negative_axis4)
|
||||||
#if INF_ENGINE_VER_MAJOR_EQ(2021040000)
|
// no filter
|
||||||
SKIP;
|
|
||||||
#endif
|
|
||||||
CASE(test_floor)
|
CASE(test_floor)
|
||||||
// no filter
|
// no filter
|
||||||
CASE(test_floor_example)
|
CASE(test_floor_example)
|
||||||
|
@ -7,12 +7,6 @@
|
|||||||
"test_castlike_FLOAT_to_STRING_expanded",
|
"test_castlike_FLOAT_to_STRING_expanded",
|
||||||
"test_castlike_STRING_to_FLOAT_expanded",
|
"test_castlike_STRING_to_FLOAT_expanded",
|
||||||
"test_concat_1d_axis_negative_1",
|
"test_concat_1d_axis_negative_1",
|
||||||
"test_flatten_axis0",
|
|
||||||
"test_flatten_axis2",
|
|
||||||
"test_flatten_axis3",
|
|
||||||
"test_flatten_negative_axis1",
|
|
||||||
"test_flatten_negative_axis2",
|
|
||||||
"test_flatten_negative_axis4",
|
|
||||||
"test_logsoftmax_default_axis",
|
"test_logsoftmax_default_axis",
|
||||||
"test_maxpool_2d_dilations",
|
"test_maxpool_2d_dilations",
|
||||||
"test_maxpool_2d_same_lower",
|
"test_maxpool_2d_same_lower",
|
||||||
|
Loading…
Reference in New Issue
Block a user