mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Fix axis
This commit is contained in:
parent
832ca0734d
commit
35c24480ae
@ -1967,7 +1967,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
LayerParams reshapeLp;
|
||||
std::string reshapeName = name + "/reshape";
|
||||
CV_Assert(layer_id.find(reshapeName) == layer_id.end());
|
||||
reshapeLp.set("axis", indices.at<int>(0));
|
||||
reshapeLp.set("axis", 0);
|
||||
reshapeLp.set("num_axes", 1);
|
||||
int newShape[] = {1, 1, -1};
|
||||
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 3));
|
||||
@ -1990,7 +1990,7 @@ void TFImporter::populateNet(Net dstNet)
|
||||
LayerParams sliceLp;
|
||||
std::string layerShapeName = name + "/slice";
|
||||
CV_Assert(layer_id.find(layerShapeName) == layer_id.end());
|
||||
sliceLp.set("axis", indices.at<int>(0));
|
||||
sliceLp.set("axis", 0);
|
||||
int begin[] = {0};
|
||||
int size[] = {1};
|
||||
sliceLp.set("begin", DictValue::arrayInt(&begin[0], 1));
|
||||
@ -2004,8 +2004,8 @@ void TFImporter::populateNet(Net dstNet)
|
||||
LayerParams squeezeLp;
|
||||
std::string squeezeName = name + "/squeeze";
|
||||
CV_Assert(layer_id.find(squeezeName) == layer_id.end());
|
||||
squeezeLp.set("axis", indices.at<int>(0));
|
||||
squeezeLp.set("end_axis", indices.at<int>(0) + 1);
|
||||
squeezeLp.set("axis", 0);
|
||||
squeezeLp.set("end_axis", 1);
|
||||
int squeezeId = dstNet.addLayer(squeezeName, "Flatten", squeezeLp);
|
||||
layer_id[squeezeName] = squeezeId;
|
||||
connect(layer_id, dstNet, Pin(layerShapeName), squeezeId, 0);
|
||||
|
Loading…
Reference in New Issue
Block a user