This commit is contained in:
Liubov Batanina 2020-01-22 13:36:29 +03:00
parent 832ca0734d
commit 35c24480ae

View File

@ -1967,7 +1967,7 @@ void TFImporter::populateNet(Net dstNet)
LayerParams reshapeLp;
std::string reshapeName = name + "/reshape";
CV_Assert(layer_id.find(reshapeName) == layer_id.end());
reshapeLp.set("axis", indices.at<int>(0));
reshapeLp.set("axis", 0);
reshapeLp.set("num_axes", 1);
int newShape[] = {1, 1, -1};
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 3));
@ -1990,7 +1990,7 @@ void TFImporter::populateNet(Net dstNet)
LayerParams sliceLp;
std::string layerShapeName = name + "/slice";
CV_Assert(layer_id.find(layerShapeName) == layer_id.end());
sliceLp.set("axis", indices.at<int>(0));
sliceLp.set("axis", 0);
int begin[] = {0};
int size[] = {1};
sliceLp.set("begin", DictValue::arrayInt(&begin[0], 1));
@ -2004,8 +2004,8 @@ void TFImporter::populateNet(Net dstNet)
LayerParams squeezeLp;
std::string squeezeName = name + "/squeeze";
CV_Assert(layer_id.find(squeezeName) == layer_id.end());
squeezeLp.set("axis", indices.at<int>(0));
squeezeLp.set("end_axis", indices.at<int>(0) + 1);
squeezeLp.set("axis", 0);
squeezeLp.set("end_axis", 1);
int squeezeId = dstNet.addLayer(squeezeName, "Flatten", squeezeLp);
layer_id[squeezeName] = squeezeId;
connect(layer_id, dstNet, Pin(layerShapeName), squeezeId, 0);