Add MatchRadiance op

own
Bobholamovic 2 years ago
parent 7b1c7adf4d
commit ee2613c4f1
  1. 1
      docs/intro/transforms.md
  2. 41
      paddlers/transforms/operators.py
  3. 14
      tests/transforms/test_functions.py
  4. 40
      tests/transforms/test_operators.py

@ -20,6 +20,7 @@ PaddleRS对不同遥感任务需要的数据预处理/数据增强(合称为
| RandomExpand | 根据随机偏移扩展输入影像。 | 所有任务 | ... | | RandomExpand | 根据随机偏移扩展输入影像。 | 所有任务 | ... |
| Pad | 将输入影像填充到指定的大小。 | 所有任务 | ... | | Pad | 将输入影像填充到指定的大小。 | 所有任务 | ... |
| MixupImage | 将两幅影像(及对应的目标检测标注)混合在一起作为新的样本。 | 目标检测 | ... | | MixupImage | 将两幅影像(及对应的目标检测标注)混合在一起作为新的样本。 | 目标检测 | ... |
| MatchRadiance | 对两个时相的输入影像进行相对辐射校正。 | 变化检测 | ... |
| RandomDistort | 对输入施加随机色彩变换。 | 所有任务 | ... | | RandomDistort | 对输入施加随机色彩变换。 | 所有任务 | ... |
| RandomBlur | 对输入施加随机模糊。 | 所有任务 | ... | | RandomBlur | 对输入施加随机模糊。 | 所有任务 | ... |
| Dehaze | 对输入图像进行去雾。 | 所有任务 | ... | | Dehaze | 对输入图像进行去雾。 | 所有任务 | ... |

@ -35,7 +35,8 @@ from .functions import (
horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle, vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8, resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape) img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape,
match_by_regression, match_histograms)
__all__ = [ __all__ = [
"Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort", "Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort",
@ -43,8 +44,9 @@ __all__ = [
"RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop", "RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
"RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort", "RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort",
"RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand", "RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand",
"ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier", "RandomFlipOrRotate", "ReloadMask", "MatchRadiance", "ArrangeSegmenter",
"ArrangeDetector", "ArrangeRestorer", "RandomFlipOrRotate", "ReloadMask" "ArrangeChangeDetector", "ArrangeClassifier", "ArrangeDetector",
"ArrangeRestorer"
] ]
interp_dict = { interp_dict = {
@ -1928,6 +1930,39 @@ class ReloadMask(Transform):
return sample return sample
class MatchRadiance(Transform):
"""
Perform relative radiometric correction between bi-temporal images.
Args:
method (str, optional): Method used to match the radiance of the
bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands
for histogram matching and 'lsr' stands for least-squares
regression. Default: 'hist'.
"""
def __init__(self, method='hist'):
super(MatchRadiance, self).__init__()
if method == 'hist':
self._match_func = match_histograms
elif method == 'lsr':
self._match_func = match_by_regression
else:
raise ValueError(
"{} is not a supported radiometric correction method.".format(
method))
self.method = method
def apply(self, sample):
if 'image2' not in sample:
raise ValueError("'image2' is not found in the sample.")
sample['image2'] = self._match_func(sample['image2'], sample['image'])
return sample
class Arrange(Transform): class Arrange(Transform):
def __init__(self, mode): def __init__(self, mode):
super().__init__() super().__init__()

@ -21,6 +21,10 @@ from data import build_input_from_file
__all__ = ['TestMatchHistograms', 'TestMatchByRegression'] __all__ = ['TestMatchHistograms', 'TestMatchByRegression']
def calc_err(a, b):
return (a - b).abs().mean()
class TestMatchHistograms(CpuCommonTest): class TestMatchHistograms(CpuCommonTest):
def setUp(self): def setUp(self):
self.inputs = [ self.inputs = [
@ -36,9 +40,11 @@ class TestMatchHistograms(CpuCommonTest):
im_out = T.functions.match_histograms(sample['image'], im_out = T.functions.match_histograms(sample['image'],
sample['image2']) sample['image2'])
self.check_output_equal(im_out.shape, sample['image2'].shape) self.check_output_equal(im_out.shape, sample['image2'].shape)
self.assertEqual(im_out.dtype, sample['image2'].dtype)
im_out = T.functions.match_histograms(sample['image2'], im_out = T.functions.match_histograms(sample['image2'],
sample['image']) sample['image'])
self.check_output_equal(im_out.shape, sample['image'].shape) self.check_output_equal(im_out.shape, sample['image'].shape)
self.assertEqual(im_out.dtype, sample['image'].dtype)
class TestMatchByRegression(CpuCommonTest): class TestMatchByRegression(CpuCommonTest):
@ -56,6 +62,14 @@ class TestMatchByRegression(CpuCommonTest):
im_out = T.functions.match_by_regression(sample['image'], im_out = T.functions.match_by_regression(sample['image'],
sample['image2']) sample['image2'])
self.check_output_equal(im_out.shape, sample['image2'].shape) self.check_output_equal(im_out.shape, sample['image2'].shape)
self.assertEqual(im_out.dtype, sample['image2'].dtype)
err1 = calc_err(sample['image'], sample['image2'])
err2 = calc_err(sample['image'], im_out)
self.assertLessEqual(err2, err1)
im_out = T.functions.match_by_regression(sample['image2'], im_out = T.functions.match_by_regression(sample['image2'],
sample['image']) sample['image'])
self.check_output_equal(im_out.shape, sample['image'].shape) self.check_output_equal(im_out.shape, sample['image'].shape)
self.assertEqual(im_out.dtype, sample['image'].dtype)
err1 = calc_err(sample['image'], sample['image2'])
err2 = calc_err(im_out, sample['image2'])
self.assertLessEqual(err2, err1)

@ -54,30 +54,30 @@ def _add_op_tests(cls):
filter_ = OP2FILTER.get(op_name, None) filter_ = OP2FILTER.get(op_name, None)
setattr( setattr(
cls, attr_name, make_test_func( cls, attr_name, make_test_func(
op_class, filter_=filter_)) op_class, _filter=filter_))
return cls return cls
def make_test_func(op_class, def make_test_func(op_class,
*args, *args,
in_hook=None, _in_hook=None,
out_hook=None, _out_hook=None,
filter_=None, _filter=None,
**kwargs): **kwargs):
def _test_func(self): def _test_func(self):
op = op_class(*args, **kwargs) op = op_class(*args, **kwargs)
decoder = T.DecodeImg() decoder = T.DecodeImg()
inputs = map(decoder, copy.deepcopy(self.inputs)) inputs = map(decoder, copy.deepcopy(self.inputs))
for i, input_ in enumerate(inputs): for i, input_ in enumerate(inputs):
if filter_ is not None: if _filter is not None:
input_ = filter_(input_) input_ = _filter(input_)
with self.subTest(i=i): with self.subTest(i=i):
for sample in input_: for sample in input_:
if in_hook: if _in_hook:
sample = in_hook(sample) sample = _in_hook(sample)
sample = op(sample) sample = op(sample)
if out_hook: if _out_hook:
sample = out_hook(sample) sample = _out_hook(sample)
return _test_func return _test_func
@ -308,15 +308,15 @@ class TestTransform(CpuCommonTest):
test_func_not_keep_ratio = make_test_func( test_func_not_keep_ratio = make_test_func(
T.Resize, T.Resize,
in_hook=_in_hook, _in_hook=_in_hook,
out_hook=_out_hook_not_keep_ratio, _out_hook=_out_hook_not_keep_ratio,
target_size=TARGET_SIZE, target_size=TARGET_SIZE,
keep_ratio=False) keep_ratio=False)
test_func_not_keep_ratio(self) test_func_not_keep_ratio(self)
test_func_keep_ratio = make_test_func( test_func_keep_ratio = make_test_func(
T.Resize, T.Resize,
in_hook=_in_hook, _in_hook=_in_hook,
out_hook=_out_hook_keep_ratio, _out_hook=_out_hook_keep_ratio,
target_size=TARGET_SIZE, target_size=TARGET_SIZE,
keep_ratio=True) keep_ratio=True)
test_func_keep_ratio(self) test_func_keep_ratio(self)
@ -345,11 +345,17 @@ class TestTransform(CpuCommonTest):
test_func = make_test_func( test_func = make_test_func(
T.RandomFlipOrRotate, T.RandomFlipOrRotate,
in_hook=_in_hook, _in_hook=_in_hook,
out_hook=_out_hook, _out_hook=_out_hook,
filter_=_filter_no_det) _filter=_filter_no_det)
test_func(self) test_func(self)
def test_MatchRadiance(self):
test_hist = make_test_func(T.MatchRadiance, 'hist')
test_hist(self)
test_lsr = make_test_func(T.MatchRadiance, 'lsr')
test_lsr(self)
class TestCompose(CpuCommonTest): class TestCompose(CpuCommonTest):
pass pass

Loading…
Cancel
Save