Add a shape checker for tflite models

pull/25372/head
ecchen 10 months ago committed by CNOCycle
parent e80500828c
commit e63690a2d9
  1. 8
      modules/dnn/test/test_tflite_importer.cpp

@ -58,7 +58,13 @@ void Test_TFLite::testModel(Net& net, const std::string& modelName, const Mat& i
ASSERT_EQ(outs.size(), outNames.size()); ASSERT_EQ(outs.size(), outNames.size());
for (int i = 0; i < outNames.size(); ++i) { for (int i = 0; i < outNames.size(); ++i) {
Mat ref = blobFromNPY(findDataFile(format("dnn/tflite/%s_out_%s.npy", modelName.c_str(), outNames[i].c_str()))); Mat ref = blobFromNPY(findDataFile(format("dnn/tflite/%s_out_%s.npy", modelName.c_str(), outNames[i].c_str())));
normAssert(ref.reshape(1, 1), outs[i].reshape(1, 1), outNames[i].c_str(), l1, lInf); // A workaround solution for the following cases due to inconsistent shape definitions.
// The details please see: https://github.com/opencv/opencv/pull/25297#issuecomment-2039081369
if (modelName == "face_landmark" || modelName == "selfie_segmentation") {
ref = ref.reshape(1, 1);
outs[i] = outs[i].reshape(1, 1);
}
normAssert(ref, outs[i], outNames[i].c_str(), l1, lInf);
} }
} }

Loading…
Cancel
Save