From 116fb9be614d5e22a3803ffdcf40858bc7368777 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Mon, 18 Jul 2022 20:56:33 +0800 Subject: [PATCH] Fix bugs --- paddlers/transforms/__init__.py | 18 ++++++++++--- paddlers/transforms/operators.py | 20 +++++++------- tests/deploy/test_predictor.py | 10 +++---- tests/transforms/test_operators.py | 42 ++++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 21 deletions(-) diff --git a/paddlers/transforms/__init__.py b/paddlers/transforms/__init__.py index 9977899..c5ad12e 100644 --- a/paddlers/transforms/__init__.py +++ b/paddlers/transforms/__init__.py @@ -23,15 +23,27 @@ from paddlers import transforms as T def decode_image(im_path, to_rgb=True, to_uint8=True, - decode_rgb=True, - decode_sar=False): + decode_bgr=True, + decode_sar=True): + """ + Decode an image. + + Args: + to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True. + to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True. + decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image. + Defaults to True. + decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a + SAR image. Defaults to True. + """ + # Do a presence check. `osp.exists` assumes `im_path` is a path-like object. 
if not osp.exists(im_path): raise ValueError(f"{im_path} does not exist!") decoder = T.DecodeImg( to_rgb=to_rgb, to_uint8=to_uint8, - decode_rgb=decode_rgb, + decode_bgr=decode_bgr, decode_sar=decode_sar) # Deepcopy to avoid inplace modification sample = {'image': copy.deepcopy(im_path)} diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py index f053264..bc36637 100644 --- a/paddlers/transforms/operators.py +++ b/paddlers/transforms/operators.py @@ -126,19 +126,21 @@ class DecodeImg(Transform): Args: to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True. to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True. - decode_rgb (bool, optional): If the image to decode is a non-geo RGB image (e.g., jpeg images), set this argument to True. Defaults to True. - decode_sar (bool, optional): If the image to decode is a SAR image, set this argument to True. Defaults to False. + decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image. + Defaults to True. + decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a + SAR image. Defaults to True. """ def __init__(self, to_rgb=True, to_uint8=True, - decode_rgb=True, - decode_sar=False): + decode_bgr=True, + decode_sar=True): super(DecodeImg, self).__init__() self.to_rgb = to_rgb self.to_uint8 = to_uint8 - self.decode_rgb = decode_rgb + self.decode_bgr = decode_bgr self.decode_sar = decode_sar def read_img(self, img_path): @@ -159,11 +161,7 @@ class DecodeImg(Transform): if dataset == None: raise IOError('Can not open', img_path) im_data = dataset.ReadAsArray() - if self.decode_sar: - if im_data.ndim != 2: - raise ValueError( - f"SAR images should have exactly 2 channels, but the image has {im_data.ndim} channels." 
- ) + if im_data.ndim == 2 and self.decode_sar: im_data = to_intensity(im_data) # is read SAR im_data = im_data[:, :, np.newaxis] else: @@ -171,7 +169,7 @@ class DecodeImg(Transform): im_data = im_data.transpose((1, 2, 0)) return im_data elif img_format in ['jpeg', 'bmp', 'png', 'jpg']: - if self.decode_rgb: + if self.decode_bgr: return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR) else: diff --git a/tests/deploy/test_predictor.py b/tests/deploy/test_predictor.py index fc1eec7..141556b 100644 --- a/tests/deploy/test_predictor.py +++ b/tests/deploy/test_predictor.py @@ -100,13 +100,9 @@ class TestPredictor(CommonTest): for key in dict_.keys(): if key in ignore_keys: continue - if isinstance(dict_[key], (list, dict)): - self.check_dict_equal( - dict_[key], expected_dict[key], ignore_keys=ignore_keys) - else: - # Use higher tolerance - self.check_output_equal( - dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6) + # Use higher tolerance + self.check_output_equal( + dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6) @TestPredictor.add_tests diff --git a/tests/transforms/test_operators.py b/tests/transforms/test_operators.py index 7b12955..f6320ae 100644 --- a/tests/transforms/test_operators.py +++ b/tests/transforms/test_operators.py @@ -116,6 +116,18 @@ def _is_mt(sample): return 'image2' in sample +def _is_seg(sample): + return 'mask' in sample and 'image2' not in sample + + +def _is_det(sample): + return 'gt_bbox' in sample or 'gt_poly' in sample + + +def _is_clas(sample): + return 'label' in sample + + _filter_only_optical = _InputFilter([_is_optical]) _filter_only_sar = _InputFilter([_is_sar]) _filter_only_multispectral = _InputFilter([_is_multispectral]) @@ -123,6 +135,7 @@ _filter_no_multispectral = _filter_only_optical | _filter_only_sar _filter_no_sar = _filter_only_optical | _filter_only_multispectral _filter_no_optical = _filter_only_sar | _filter_only_multispectral _filter_only_mt = _InputFilter([_is_mt]) 
+_filter_no_det = _InputFilter([_is_seg, _is_clas, _is_mt]) OP2FILTER = { 'RandomSwap': _filter_only_mt, @@ -262,6 +275,35 @@ class TestTransform(CpuCommonTest): keep_ratio=True) test_func_keep_ratio(self) + def test_RandomFlipOrRotate(self): + def _in_hook(sample): + if 'image2' in sample: + self.im_diff = ( + sample['image'] - sample['image2']).astype('float64') + elif 'mask' in sample: + self.im_diff = ( + sample['image'][..., 0] - sample['mask']).astype('float64') + return sample + + def _out_hook(sample): + im_diff = None + if 'image2' in sample: + im_diff = (sample['image'] - sample['image2']).astype('float64') + elif 'mask' in sample: + im_diff = ( + sample['image'][..., 0] - sample['mask']).astype('float64') + if im_diff is not None: + self.check_output_equal(im_diff.max(), self.im_diff.max()) + self.check_output_equal(im_diff.min(), self.im_diff.min()) + return sample + + test_func = make_test_func( + T.RandomFlipOrRotate, + in_hook=_in_hook, + out_hook=_out_hook, + filter_=_filter_no_det) + test_func(self) + class TestCompose(CpuCommonTest): pass