mirror of
https://github.com/opencv/opencv.git
synced 2025-07-24 14:06:27 +08:00
Merge pull request #22448 from Ichini24:reshape-permutations-fix
changed names of permutations if Reshpe is in NHWC
This commit is contained in:
commit
c2c8da2517
@ -1097,6 +1097,9 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD
|
||||
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
|
||||
hasSwap = true;
|
||||
}
|
||||
|
||||
bool changedType{false};
|
||||
|
||||
if (inpLayout == DATA_LAYOUT_NHWC)
|
||||
{
|
||||
if (newShapeSize >= 2 || newShape.at<int>(1) == 1)
|
||||
@ -1110,23 +1113,28 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD
|
||||
else
|
||||
{
|
||||
inpLayout = DATA_LAYOUT_NHWC;
|
||||
changedType = newShapeSize == 4 && !hasSwap;
|
||||
}
|
||||
}
|
||||
}
|
||||
layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShapeSize));
|
||||
|
||||
int id = dstNet.addLayer(name, "Reshape", layerParams);
|
||||
layer_id[name] = id;
|
||||
std::string setName = changedType ? name + "/realReshape" : name;
|
||||
|
||||
int id = dstNet.addLayer(setName, "Reshape", layerParams);
|
||||
layer_id[setName] = id;
|
||||
|
||||
// one input only
|
||||
connect(layer_id, dstNet, inpId, id, 0);
|
||||
inpId = Pin(name);
|
||||
inpId = Pin(setName);
|
||||
|
||||
if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) &&
|
||||
newShapeSize == 4 && !hasSwap)
|
||||
{
|
||||
int order[] = {0, 3, 1, 2}; // Transform back to OpenCV's NCHW.
|
||||
addPermuteLayer(order, name + "/nchw", inpId);
|
||||
|
||||
setName = changedType ? name : name + "/nchw";
|
||||
addPermuteLayer(order, setName, inpId);
|
||||
inpLayout = DATA_LAYOUT_NCHW;
|
||||
}
|
||||
|
||||
|
@ -337,6 +337,12 @@ TEST_P(Test_TensorFlow_layers, eltwise_mul_vec)
|
||||
runTensorFlowNet("eltwise_mul_vec");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, tf_reshape_nhwc)
|
||||
{
|
||||
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
|
||||
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
|
||||
runTensorFlowNet("tf_reshape_nhwc");
|
||||
}
|
||||
|
||||
TEST_P(Test_TensorFlow_layers, channel_broadcast)
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user