mirror of
https://github.com/opencv/opencv.git
synced 2025-01-18 22:44:02 +08:00
Add a shape checker for tflite models
This commit is contained in:
parent
e80500828c
commit
e63690a2d9
@ -58,7 +58,13 @@ void Test_TFLite::testModel(Net& net, const std::string& modelName, const Mat& i
|
||||
ASSERT_EQ(outs.size(), outNames.size());
|
||||
for (int i = 0; i < outNames.size(); ++i) {
|
||||
Mat ref = blobFromNPY(findDataFile(format("dnn/tflite/%s_out_%s.npy", modelName.c_str(), outNames[i].c_str())));
|
||||
normAssert(ref.reshape(1, 1), outs[i].reshape(1, 1), outNames[i].c_str(), l1, lInf);
|
||||
// A workaround solution for the following cases due to inconsistent shape definitions.
|
||||
// The details please see: https://github.com/opencv/opencv/pull/25297#issuecomment-2039081369
|
||||
if (modelName == "face_landmark" || modelName == "selfie_segmentation") {
|
||||
ref = ref.reshape(1, 1);
|
||||
outs[i] = outs[i].reshape(1, 1);
|
||||
}
|
||||
normAssert(ref, outs[i], outNames[i].c_str(), l1, lInf);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user