From e63690a2d9b0fb50c687b239f2d699436b442447 Mon Sep 17 00:00:00 2001 From: ecchen Date: Sat, 6 Apr 2024 13:55:17 +0000 Subject: [PATCH] Add a shape checker for tflite models --- modules/dnn/test/test_tflite_importer.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/modules/dnn/test/test_tflite_importer.cpp b/modules/dnn/test/test_tflite_importer.cpp index 291d1f50d2..7ad62bf308 100644 --- a/modules/dnn/test/test_tflite_importer.cpp +++ b/modules/dnn/test/test_tflite_importer.cpp @@ -58,7 +58,13 @@ void Test_TFLite::testModel(Net& net, const std::string& modelName, const Mat& i ASSERT_EQ(outs.size(), outNames.size()); for (int i = 0; i < outNames.size(); ++i) { Mat ref = blobFromNPY(findDataFile(format("dnn/tflite/%s_out_%s.npy", modelName.c_str(), outNames[i].c_str()))); - normAssert(ref.reshape(1, 1), outs[i].reshape(1, 1), outNames[i].c_str(), l1, lInf); + // A workaround solution for the following cases due to inconsistent shape definitions. + // The details please see: https://github.com/opencv/opencv/pull/25297#issuecomment-2039081369 + if (modelName == "face_landmark" || modelName == "selfie_segmentation") { + ref = ref.reshape(1, 1); + outs[i] = outs[i].reshape(1, 1); + } + normAssert(ref, outs[i], outNames[i].c_str(), l1, lInf); } }