From ee2613c4f1707b222ed372374581670275c7558b Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Fri, 2 Sep 2022 10:36:03 +0800 Subject: [PATCH] Add MatchRadiance op --- docs/intro/transforms.md | 1 + paddlers/transforms/operators.py | 41 +++++++++++++++++++++++++++--- tests/transforms/test_functions.py | 14 ++++++++++ tests/transforms/test_operators.py | 40 ++++++++++++++++------------- 4 files changed, 76 insertions(+), 20 deletions(-) diff --git a/docs/intro/transforms.md b/docs/intro/transforms.md index c7234de..8bc25ee 100644 --- a/docs/intro/transforms.md +++ b/docs/intro/transforms.md @@ -20,6 +20,7 @@ PaddleRS对不同遥感任务需要的数据预处理/数据增强(合称为 | RandomExpand | 根据随机偏移扩展输入影像。 | 所有任务 | ... | | Pad | 将输入影像填充到指定的大小。 | 所有任务 | ... | | MixupImage | 将两幅影像(及对应的目标检测标注)混合在一起作为新的样本。 | 目标检测 | ... | +| MatchRadiance | 对两个时相的输入影像进行相对辐射校正。 | 变化检测 | ... | | RandomDistort | 对输入施加随机色彩变换。 | 所有任务 | ... | | RandomBlur | 对输入施加随机模糊。 | 所有任务 | ... | | Dehaze | 对输入图像进行去雾。 | 所有任务 | ... | diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index dd21c7a..143ea6a 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -35,7 +35,8 @@ from .functions import ( horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8, - img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape) + img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape, + match_by_regression, match_histograms) __all__ = [ "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort", @@ -43,8 +44,9 @@ __all__ = [ "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop", "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort", "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand", - "ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier", - "ArrangeDetector", "ArrangeRestorer", "RandomFlipOrRotate", "ReloadMask" + "RandomFlipOrRotate", "ReloadMask", "MatchRadiance", "ArrangeSegmenter", + "ArrangeChangeDetector", "ArrangeClassifier", "ArrangeDetector", + "ArrangeRestorer" ] interp_dict = { @@ -1928,6 +1930,39 @@ class ReloadMask(Transform): return sample +class MatchRadiance(Transform): + """ + Perform relative radiometric correction between bi-temporal images. + + Args: + method (str, optional): Method used to match the radiance of the + bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands + for histogram matching and 'lsr' stands for least-squares + regression. Default: 'hist'. + """ + + def __init__(self, method='hist'): + super(MatchRadiance, self).__init__() + + if method == 'hist': + self._match_func = match_histograms + elif method == 'lsr': + self._match_func = match_by_regression + else: + raise ValueError( + "{} is not a supported radiometric correction method.".format( + method)) + + self.method = method + + def apply(self, sample): + if 'image2' not in sample: + raise ValueError("'image2' is not found in the sample.") + + sample['image2'] = self._match_func(sample['image2'], sample['image']) + return sample + + class Arrange(Transform): def __init__(self, mode): super().__init__() diff --git a/tests/transforms/test_functions.py b/tests/transforms/test_functions.py index f51152e..692d8a3 100644 --- a/tests/transforms/test_functions.py +++ b/tests/transforms/test_functions.py @@ -21,6 +21,10 @@ from data import build_input_from_file __all__ = ['TestMatchHistograms', 'TestMatchByRegression'] +def calc_err(a, b): + return (a - b).abs().mean() + + class TestMatchHistograms(CpuCommonTest): def setUp(self): self.inputs = [ @@ -36,9 +40,11 @@ class TestMatchHistograms(CpuCommonTest): im_out = T.functions.match_histograms(sample['image'], sample['image2']) self.check_output_equal(im_out.shape, sample['image2'].shape) + self.assertEqual(im_out.dtype, sample['image2'].dtype) im_out = T.functions.match_histograms(sample['image2'], sample['image']) self.check_output_equal(im_out.shape, sample['image'].shape) + self.assertEqual(im_out.dtype, sample['image'].dtype) class TestMatchByRegression(CpuCommonTest): @@ -56,6 +62,14 @@ class TestMatchByRegression(CpuCommonTest): im_out = T.functions.match_by_regression(sample['image'], sample['image2']) self.check_output_equal(im_out.shape, sample['image2'].shape) + self.assertEqual(im_out.dtype, sample['image2'].dtype) + err1 = calc_err(sample['image'], sample['image2']) + err2 = calc_err(sample['image'], im_out) + self.assertLessEqual(err2, err1) im_out = T.functions.match_by_regression(sample['image2'], sample['image']) self.check_output_equal(im_out.shape, sample['image'].shape) + self.assertEqual(im_out.dtype, sample['image'].dtype) + err1 = calc_err(sample['image'], sample['image2']) + err2 = calc_err(im_out, sample['image2']) + self.assertLessEqual(err2, err1) diff --git a/tests/transforms/test_operators.py b/tests/transforms/test_operators.py index cff8428..c969e06 100644 --- a/tests/transforms/test_operators.py +++ b/tests/transforms/test_operators.py @@ -54,30 +54,30 @@ def _add_op_tests(cls): filter_ = OP2FILTER.get(op_name, None) setattr( cls, attr_name, make_test_func( - op_class, filter_=filter_)) + op_class, _filter=filter_)) return cls def make_test_func(op_class, *args, - in_hook=None, - out_hook=None, - filter_=None, + _in_hook=None, + _out_hook=None, + _filter=None, **kwargs): def _test_func(self): op = op_class(*args, **kwargs) decoder = T.DecodeImg() inputs = map(decoder, copy.deepcopy(self.inputs)) for i, input_ in enumerate(inputs): - if filter_ is not None: - input_ = filter_(input_) + if _filter is not None: + input_ = _filter(input_) with self.subTest(i=i): for sample in input_: - if in_hook: - sample = in_hook(sample) + if _in_hook: + sample = _in_hook(sample) sample = op(sample) - if out_hook: - sample = out_hook(sample) + if _out_hook: + sample = _out_hook(sample) return _test_func @@ -308,15 +308,15 @@ class TestTransform(CpuCommonTest): test_func_not_keep_ratio = make_test_func( T.Resize, - in_hook=_in_hook, - out_hook=_out_hook_not_keep_ratio, + _in_hook=_in_hook, + _out_hook=_out_hook_not_keep_ratio, target_size=TARGET_SIZE, keep_ratio=False) test_func_not_keep_ratio(self) test_func_keep_ratio = make_test_func( T.Resize, - in_hook=_in_hook, - out_hook=_out_hook_keep_ratio, + _in_hook=_in_hook, + _out_hook=_out_hook_keep_ratio, target_size=TARGET_SIZE, keep_ratio=True) test_func_keep_ratio(self) @@ -345,11 +345,17 @@ class TestTransform(CpuCommonTest): test_func = make_test_func( T.RandomFlipOrRotate, - in_hook=_in_hook, - out_hook=_out_hook, - filter_=_filter_no_det) + _in_hook=_in_hook, + _out_hook=_out_hook, + _filter=_filter_no_det) test_func(self) + def test_MatchRadiance(self): + test_hist = make_test_func(T.MatchRadiance, 'hist') + test_hist(self) + test_lsr = make_test_func(T.MatchRadiance, 'lsr') + test_lsr(self) + class TestCompose(CpuCommonTest): pass