From 587b4451aca43fb556e687c620ad434af13bfed6 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Fri, 2 Sep 2022 20:21:20 +0800 Subject: [PATCH] Fix bugs in unittests --- tests/transforms/test_functions.py | 26 +++++++++++++++++--------- tests/transforms/test_operators.py | 9 +++++---- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/tests/transforms/test_functions.py b/tests/transforms/test_functions.py index 692d8a3..3e8d9bb 100644 --- a/tests/transforms/test_functions.py +++ b/tests/transforms/test_functions.py @@ -14,6 +14,8 @@ import copy +import numpy as np + import paddlers.transforms as T from testing_utils import CpuCommonTest from data import build_input_from_file @@ -22,7 +24,9 @@ __all__ = ['TestMatchHistograms', 'TestMatchByRegression'] def calc_err(a, b): - return (a - b).abs().mean() + a = a.astype('float64') + b = b.astype('float64') + return np.abs(a - b).mean() class TestMatchHistograms(CpuCommonTest): @@ -37,12 +41,14 @@ class TestMatchHistograms(CpuCommonTest): for input in copy.deepcopy(self.inputs): for sample in input: sample = decoder(sample) - im_out = T.functions.match_histograms(sample['image'], - sample['image2']) - self.check_output_equal(im_out.shape, sample['image2'].shape) - self.assertEqual(im_out.dtype, sample['image2'].dtype) + im_out = T.functions.match_histograms(sample['image2'], sample['image']) + self.check_output_equal(im_out.shape, sample['image2'].shape) + self.assertEqual(im_out.dtype, sample['image2'].dtype) + + im_out = T.functions.match_histograms(sample['image'], + sample['image2']) self.check_output_equal(im_out.shape, sample['image'].shape) self.assertEqual(im_out.dtype, sample['image'].dtype) @@ -59,15 +65,17 @@ class TestMatchByRegression(CpuCommonTest): for input in copy.deepcopy(self.inputs): for sample in input: sample = decoder(sample) - im_out = T.functions.match_by_regression(sample['image'], - sample['image2']) + + im_out = T.functions.match_by_regression(sample['image2'], + sample['image']) self.check_output_equal(im_out.shape, sample['image2'].shape) self.assertEqual(im_out.dtype, sample['image2'].dtype) err1 = calc_err(sample['image'], sample['image2']) err2 = calc_err(sample['image'], im_out) + self.assertLessEqual(err2, err1) - im_out = T.functions.match_by_regression(sample['image2'], - sample['image']) + im_out = T.functions.match_by_regression(sample['image'], + sample['image2']) self.check_output_equal(im_out.shape, sample['image'].shape) self.assertEqual(im_out.dtype, sample['image'].dtype) err1 = calc_err(sample['image'], sample['image2']) diff --git a/tests/transforms/test_operators.py b/tests/transforms/test_operators.py index 6b3f288..0d8024d 100644 --- a/tests/transforms/test_operators.py +++ b/tests/transforms/test_operators.py @@ -145,8 +145,7 @@ OP2FILTER = { 'SelectBand': _filter_no_sar, 'Dehaze': _filter_only_optical, 'Normalize': _filter_only_optical, - 'RandomDistort': _filter_only_optical, - 'MatchRadiance': _filter_only_mt + 'RandomDistort': _filter_only_optical } @@ -352,9 +351,11 @@ class TestTransform(CpuCommonTest): test_func(self) def test_MatchRadiance(self): - test_hist = make_test_func(T.MatchRadiance, 'hist') + test_hist = make_test_func( + T.MatchRadiance, 'hist', _filter=_filter_only_mt) test_hist(self) - test_lsr = make_test_func(T.MatchRadiance, 'lsr') + test_lsr = make_test_func( + T.MatchRadiance, 'lsr', _filter=_filter_only_mt) test_lsr(self)