own
Bobholamovic 2 years ago
parent 61f818411c
commit 116fb9be61
  1. 18
      paddlers/transforms/__init__.py
  2. 20
      paddlers/transforms/operators.py
  3. 10
      tests/deploy/test_predictor.py
  4. 42
      tests/transforms/test_operators.py

@ -23,15 +23,27 @@ from paddlers import transforms as T
def decode_image(im_path,
to_rgb=True,
to_uint8=True,
decode_rgb=True,
decode_sar=False):
decode_bgr=True,
decode_sar=True):
"""
Decode an image.
Args:
to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g. jpeg images) as a BGR image.
Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a
SAR image, set this argument to True. Defaults to True.
"""
# Do a presence check. `osp.exists` assumes `im_path` is a path-like object.
if not osp.exists(im_path):
raise ValueError(f"{im_path} does not exist!")
decoder = T.DecodeImg(
to_rgb=to_rgb,
to_uint8=to_uint8,
decode_rgb=decode_rgb,
decode_bgr=decode_bgr,
decode_sar=decode_sar)
# Deepcopy to avoid inplace modification
sample = {'image': copy.deepcopy(im_path)}

@ -126,19 +126,21 @@ class DecodeImg(Transform):
Args:
to_rgb (bool, optional): If True, convert input image(s) from BGR format to RGB format. Defaults to True.
to_uint8 (bool, optional): If True, quantize and convert decoded image(s) to uint8 type. Defaults to True.
decode_rgb (bool, optional): If the image to decode is a non-geo RGB image (e.g., jpeg images), set this argument to True. Defaults to True.
decode_sar (bool, optional): If the image to decode is a SAR image, set this argument to True. Defaults to False.
decode_bgr (bool, optional): If True, automatically interpret a non-geo image (e.g., jpeg images) as a BGR image.
Defaults to True.
decode_sar (bool, optional): If True, automatically interpret a two-channel geo image (e.g. geotiff images) as a
SAR image, set this argument to True. Defaults to True.
"""
def __init__(self,
to_rgb=True,
to_uint8=True,
decode_rgb=True,
decode_sar=False):
decode_bgr=True,
decode_sar=True):
super(DecodeImg, self).__init__()
self.to_rgb = to_rgb
self.to_uint8 = to_uint8
self.decode_rgb = decode_rgb
self.decode_bgr = decode_bgr
self.decode_sar = decode_sar
def read_img(self, img_path):
@ -159,11 +161,7 @@ class DecodeImg(Transform):
if dataset == None:
raise IOError('Can not open', img_path)
im_data = dataset.ReadAsArray()
if self.decode_sar:
if im_data.ndim != 2:
raise ValueError(
f"SAR images should have exactly 2 channels, but the image has {im_data.ndim} channels."
)
if im_data.ndim == 2 and self.decode_sar:
im_data = to_intensity(im_data) # is read SAR
im_data = im_data[:, :, np.newaxis]
else:
@ -171,7 +169,7 @@ class DecodeImg(Transform):
im_data = im_data.transpose((1, 2, 0))
return im_data
elif img_format in ['jpeg', 'bmp', 'png', 'jpg']:
if self.decode_rgb:
if self.decode_bgr:
return cv2.imread(img_path, cv2.IMREAD_ANYDEPTH |
cv2.IMREAD_ANYCOLOR | cv2.IMREAD_COLOR)
else:

@ -100,13 +100,9 @@ class TestPredictor(CommonTest):
for key in dict_.keys():
if key in ignore_keys:
continue
if isinstance(dict_[key], (list, dict)):
self.check_dict_equal(
dict_[key], expected_dict[key], ignore_keys=ignore_keys)
else:
# Use higher tolerance
self.check_output_equal(
dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
# Use higher tolerance
self.check_output_equal(
dict_[key], expected_dict[key], rtol=1.e-4, atol=1.e-6)
@TestPredictor.add_tests

@ -116,6 +116,18 @@ def _is_mt(sample):
return 'image2' in sample
def _is_seg(sample):
return 'mask' in sample and 'image2' not in sample
def _is_det(sample):
return 'gt_bbox' in sample or 'gt_poly' in sample
def _is_clas(sample):
return 'label' in sample
_filter_only_optical = _InputFilter([_is_optical])
_filter_only_sar = _InputFilter([_is_sar])
_filter_only_multispectral = _InputFilter([_is_multispectral])
@ -123,6 +135,7 @@ _filter_no_multispectral = _filter_only_optical | _filter_only_sar
_filter_no_sar = _filter_only_optical | _filter_only_multispectral
_filter_no_optical = _filter_only_sar | _filter_only_multispectral
_filter_only_mt = _InputFilter([_is_mt])
_filter_no_det = _InputFilter([_is_seg, _is_clas, _is_mt])
OP2FILTER = {
'RandomSwap': _filter_only_mt,
@ -262,6 +275,35 @@ class TestTransform(CpuCommonTest):
keep_ratio=True)
test_func_keep_ratio(self)
def test_RandomFlipOrRotate(self):
def _in_hook(sample):
if 'image2' in sample:
self.im_diff = (
sample['image'] - sample['image2']).astype('float64')
elif 'mask' in sample:
self.im_diff = (
sample['image'][..., 0] - sample['mask']).astype('float64')
return sample
def _out_hook(sample):
im_diff = None
if 'image2' in sample:
im_diff = (sample['image'] - sample['image2']).astype('float64')
elif 'mask' in sample:
im_diff = (
sample['image'][..., 0] - sample['mask']).astype('float64')
if im_diff is not None:
self.check_output_equal(im_diff.max(), self.im_diff.max())
self.check_output_equal(im_diff.min(), self.im_diff.min())
return sample
test_func = make_test_func(
T.RandomFlipOrRotate,
in_hook=_in_hook,
out_hook=_out_hook,
filter_=_filter_no_det)
test_func(self)
class TestCompose(CpuCommonTest):
pass

Loading…
Cancel
Save