Fix bugs in unittests

own
Bobholamovic 2 years ago
parent 2fd1689f63
commit 587b4451ac
  1. 26
      tests/transforms/test_functions.py
  2. 9
      tests/transforms/test_operators.py

@ -14,6 +14,8 @@
import copy
import numpy as np
import paddlers.transforms as T
from testing_utils import CpuCommonTest
from data import build_input_from_file
@ -22,7 +24,9 @@ __all__ = ['TestMatchHistograms', 'TestMatchByRegression']
def calc_err(a, b):
return (a - b).abs().mean()
a = a.astype('float64')
b = b.astype('float64')
return np.abs(a - b).mean()
class TestMatchHistograms(CpuCommonTest):
@ -37,12 +41,14 @@ class TestMatchHistograms(CpuCommonTest):
for input in copy.deepcopy(self.inputs):
for sample in input:
sample = decoder(sample)
im_out = T.functions.match_histograms(sample['image'],
sample['image2'])
self.check_output_equal(im_out.shape, sample['image2'].shape)
self.assertEqual(im_out.dtype, sample['image2'].dtype)
im_out = T.functions.match_histograms(sample['image2'],
sample['image'])
self.check_output_equal(im_out.shape, sample['image2'].shape)
self.assertEqual(im_out.dtype, sample['image2'].dtype)
im_out = T.functions.match_histograms(sample['image'],
sample['image2'])
self.check_output_equal(im_out.shape, sample['image'].shape)
self.assertEqual(im_out.dtype, sample['image'].dtype)
@ -59,15 +65,17 @@ class TestMatchByRegression(CpuCommonTest):
for input in copy.deepcopy(self.inputs):
for sample in input:
sample = decoder(sample)
im_out = T.functions.match_by_regression(sample['image'],
sample['image2'])
im_out = T.functions.match_by_regression(sample['image2'],
sample['image'])
self.check_output_equal(im_out.shape, sample['image2'].shape)
self.assertEqual(im_out.dtype, sample['image2'].dtype)
err1 = calc_err(sample['image'], sample['image2'])
err2 = calc_err(sample['image'], im_out)
self.assertLessEqual(err2, err1)
im_out = T.functions.match_by_regression(sample['image2'],
sample['image'])
im_out = T.functions.match_by_regression(sample['image'],
sample['image2'])
self.check_output_equal(im_out.shape, sample['image'].shape)
self.assertEqual(im_out.dtype, sample['image'].dtype)
err1 = calc_err(sample['image'], sample['image2'])

@ -145,8 +145,7 @@ OP2FILTER = {
'SelectBand': _filter_no_sar,
'Dehaze': _filter_only_optical,
'Normalize': _filter_only_optical,
'RandomDistort': _filter_only_optical,
'MatchRadiance': _filter_only_mt
'RandomDistort': _filter_only_optical
}
@ -352,9 +351,11 @@ class TestTransform(CpuCommonTest):
test_func(self)
def test_MatchRadiance(self):
test_hist = make_test_func(T.MatchRadiance, 'hist')
test_hist = make_test_func(
T.MatchRadiance, 'hist', _filter=_filter_only_mt)
test_hist(self)
test_lsr = make_test_func(T.MatchRadiance, 'lsr')
test_lsr = make_test_func(
T.MatchRadiance, 'lsr', _filter=_filter_only_mt)
test_lsr(self)

Loading…
Cancel
Save