mirror of
https://github.com/opencv/opencv.git
synced 2024-11-25 19:50:38 +08:00
Merge pull request #15303 from dkurt:fix_15296
This commit is contained in:
commit
84b8a2fb05
@ -1129,15 +1129,14 @@ void TFImporter::populateNet(Net dstNet)
|
|||||||
if (value_id.find(layer.input(1)) != value_id.end())
|
if (value_id.find(layer.input(1)) != value_id.end())
|
||||||
{
|
{
|
||||||
Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
|
Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
|
||||||
|
if (newShape.total() == 4)
|
||||||
|
{
|
||||||
|
// NHWC->NCHW
|
||||||
|
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
|
||||||
|
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
|
||||||
|
}
|
||||||
if (inpLayout == DATA_LAYOUT_NHWC)
|
if (inpLayout == DATA_LAYOUT_NHWC)
|
||||||
{
|
{
|
||||||
if (newShape.total() == 4)
|
|
||||||
{
|
|
||||||
// NHWC->NCHW
|
|
||||||
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
|
|
||||||
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
|
|
||||||
}
|
|
||||||
if (newShape.total() != 4 || newShape.at<int>(1) == 1)
|
if (newShape.total() != 4 || newShape.at<int>(1) == 1)
|
||||||
{
|
{
|
||||||
LayerParams permLP;
|
LayerParams permLP;
|
||||||
|
@ -279,7 +279,7 @@ TEST_P(Test_TensorFlow_layers, matmul)
|
|||||||
// Reference output values are in range [-5.688, 4.484]
|
// Reference output values are in range [-5.688, 4.484]
|
||||||
double l1 = target == DNN_TARGET_MYRIAD ? 6.1e-3 : default_l1;
|
double l1 = target == DNN_TARGET_MYRIAD ? 6.1e-3 : default_l1;
|
||||||
runTensorFlowNet("nhwc_reshape_matmul", false, l1);
|
runTensorFlowNet("nhwc_reshape_matmul", false, l1);
|
||||||
|
runTensorFlowNet("matmul_layout");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(Test_TensorFlow_layers, reshape)
|
TEST_P(Test_TensorFlow_layers, reshape)
|
||||||
|
Loading…
Reference in New Issue
Block a user