From e7015f6ae872b2ca2260d4387d74731ef9dc26f8 Mon Sep 17 00:00:00 2001 From: Dmitry Kurtaev Date: Fri, 19 Oct 2018 17:43:26 +0300 Subject: [PATCH] Fix ENet test --- modules/dnn/test/test_torch_importer.cpp | 46 ++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/modules/dnn/test/test_torch_importer.cpp b/modules/dnn/test/test_torch_importer.cpp index dd7d975af6..0b844452e2 100644 --- a/modules/dnn/test/test_torch_importer.cpp +++ b/modules/dnn/test/test_torch_importer.cpp @@ -287,6 +287,46 @@ TEST_P(Test_Torch_nets, OpenFace_accuracy) normAssert(out, outRef, "", default_l1, default_lInf); } +static Mat getSegmMask(const Mat& scores) +{ + const int rows = scores.size[2]; + const int cols = scores.size[3]; + const int numClasses = scores.size[1]; + + Mat maxCl = Mat::zeros(rows, cols, CV_8UC1); + Mat maxVal(rows, cols, CV_32FC1, Scalar(0)); + for (int ch = 0; ch < numClasses; ch++) + { + for (int row = 0; row < rows; row++) + { + const float *ptrScore = scores.ptr(0, ch, row); + uint8_t *ptrMaxCl = maxCl.ptr(row); + float *ptrMaxVal = maxVal.ptr(row); + for (int col = 0; col < cols; col++) + { + if (ptrScore[col] > ptrMaxVal[col]) + { + ptrMaxVal[col] = ptrScore[col]; + ptrMaxCl[col] = (uchar)ch; + } + } + } + } + return maxCl; +} + +// Computer per-class intersection over union metric. +static void normAssertSegmentation(const Mat& ref, const Mat& test) +{ + CV_Assert_N(ref.dims == 4, test.dims == 4); + const int numClasses = ref.size[1]; + CV_Assert(numClasses == test.size[1]); + + Mat refMask = getSegmMask(ref); + Mat testMask = getSegmMask(test); + EXPECT_EQ(countNonZero(refMask != testMask), 0); +} + TEST_P(Test_Torch_nets, ENet_accuracy) { checkBackend(); @@ -313,14 +353,16 @@ TEST_P(Test_Torch_nets, ENet_accuracy) // Due to numerical instability in Pooling-Unpooling layers (indexes jittering) // thresholds for ENet must be changed. Accuracy of results was checked on // Cityscapes dataset and difference in mIOU with Torch is 10E-4% - normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.5); + normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552); + normAssertSegmentation(ref, out); const int N = 3; for (int i = 0; i < N; i++) { net.setInput(inputBlob, ""); Mat out = net.forward(); - normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.5); + normAssert(ref, out, "", 0.00044, /*target == DNN_TARGET_CPU ? 0.453 : */0.552); + normAssertSegmentation(ref, out); } }