diff --git a/paddlers/transforms/functions.py b/paddlers/transforms/functions.py
index 5550e33..9e70676 100644
--- a/paddlers/transforms/functions.py
+++ b/paddlers/transforms/functions.py
@@ -604,6 +604,55 @@ def match_by_regression(im, ref, pif_loc=None):
     return matched
 
 
+def match_lf_components(im, ref, lf_ratio=0.01):
+    """
+    Match the low-frequency components of two images.
+
+    Args:
+        im (np.ndarray): Input image.
+        ref (np.ndarray): Reference image to match. `ref` must have the same shape
+            as `im`.
+        lf_ratio (float, optional): Proportion of frequency components that should
+            be recognized as low-frequency components in the frequency domain.
+            Default: 0.01.
+
+    Returns:
+        np.ndarray: Transformed input image.
+
+    Raises:
+        ValueError: When the shape of `ref` differs from that of `im`.
+    """
+
+    def _replace_lf(im, ref, lf_ratio):
+        h, w = im.shape
+        h_lf, w_lf = int(h // 2 * lf_ratio), int(w // 2 * lf_ratio)
+        freq_im = np.fft.fft2(im)
+        freq_ref = np.fft.fft2(ref)
+        if h_lf > 0:
+            freq_im[:h_lf] = freq_ref[:h_lf]
+            freq_im[-h_lf:] = freq_ref[-h_lf:]
+        if w_lf > 0:
+            freq_im[:, :w_lf] = freq_ref[:, :w_lf]
+            freq_im[:, -w_lf:] = freq_ref[:, -w_lf:]
+        recon_im = np.fft.ifft2(freq_im)
+        recon_im = np.abs(recon_im)
+        return recon_im
+
+    if im.shape != ref.shape:
+        raise ValueError("Image and reference must have the same shape!")
+
+    if im.ndim > 2:
+        # Multiple channels
+        matched = np.empty(im.shape, dtype=im.dtype)
+        for ch in range(im.shape[-1]):
+            matched[..., ch] = _replace_lf(im[..., ch], ref[..., ch], lf_ratio)
+    else:
+        # Single channel
+        matched = _replace_lf(im, ref, lf_ratio).astype(im.dtype)
+
+    return matched
+
+
 def inv_pca(im, joblib_path):
     """
     Perform inverse PCA transformation.
diff --git a/paddlers/transforms/operators.py b/paddlers/transforms/operators.py
index fe34878..0a3fd5c 100644
--- a/paddlers/transforms/operators.py
+++ b/paddlers/transforms/operators.py
@@ -27,14 +27,8 @@ from PIL import Image
 from joblib import load
 
 import paddlers
+import paddlers.transforms.functions as F
 import paddlers.transforms.indices as indices
-from .functions import (
-    normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
-    horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
-    vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
-    resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
-    img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape,
-    match_by_regression, match_histograms)
 
 __all__ = [
     "Compose",
@@ -243,7 +237,7 @@ class DecodeImg(Transform):
                 raise IOError('Cannot open', img_path)
             im_data = dataset.ReadAsArray()
             if im_data.ndim == 2 and self.decode_sar:
-                im_data = to_intensity(im_data)
+                im_data = F.to_intensity(im_data)
                 im_data = im_data[:, :, np.newaxis]
             else:
                 if im_data.ndim == 3:
@@ -287,7 +281,7 @@ class DecodeImg(Transform):
             image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
         if self.to_uint8:
-            image = to_uint8(image)
+            image = F.to_uint8(image)
 
         if self.read_geo_info:
             return image, geo_info_dict
@@ -442,15 +436,15 @@ class Resize(Transform):
         im_scale_x, im_scale_y = scale
         resized_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 resized_segms.append([
-                    resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
+                    F.resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
                 ])
             else:
                 # RLE format
                 resized_segms.append(
-                    resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
+                    F.resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
 
         return resized_segms
 
@@ -495,7 +489,7 @@ class Resize(Transform):
                 # For SR tasks
                 sample['target'] = self.apply_im(
                     sample['target'], interp,
-                    calc_hr_shape(target_size, sample['sr_factor']))
+                    F.calc_hr_shape(target_size, sample['sr_factor']))
             else:
                 # For non-SR tasks
                 sample['target'] = self.apply_im(sample['target'], interp,
@@ -692,16 +686,16 @@ class RandomFlipOrRotate(Transform):
 
     def apply_im(self, image, mode_id, flip_mode=True):
         if flip_mode:
-            image = img_flip(image, mode_id)
+            image = F.img_flip(image, mode_id)
         else:
-            image = img_simple_rotate(image, mode_id)
+            image = F.img_simple_rotate(image, mode_id)
         return image
 
     def apply_mask(self, mask, mode_id, flip_mode=True):
         if flip_mode:
-            mask = img_flip(mask, mode_id)
+            mask = F.img_flip(mask, mode_id)
         else:
-            mask = img_simple_rotate(mask, mode_id)
+            mask = F.img_simple_rotate(mask, mode_id)
         return mask
 
     def apply_bbox(self, bbox, mode_id, flip_mode=True):
@@ -817,11 +811,11 @@ class RandomHorizontalFlip(Transform):
         self.prob = prob
 
     def apply_im(self, image):
-        image = horizontal_flip(image)
+        image = F.horizontal_flip(image)
         return image
 
     def apply_mask(self, mask):
-        mask = horizontal_flip(mask)
+        mask = F.horizontal_flip(mask)
         return mask
 
     def apply_bbox(self, bbox, width):
@@ -834,13 +828,13 @@ class RandomHorizontalFlip(Transform):
     def apply_segm(self, segms, height, width):
         flipped_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 flipped_segms.append(
-                    [horizontal_flip_poly(poly, width) for poly in segm])
+                    [F.horizontal_flip_poly(poly, width) for poly in segm])
             else:
                 # RLE format
-                flipped_segms.append(horizontal_flip_rle(segm, height, width))
+                flipped_segms.append(F.horizontal_flip_rle(segm, height, width))
         return flipped_segms
 
     def apply(self, sample):
@@ -877,11 +871,11 @@ class RandomVerticalFlip(Transform):
         self.prob = prob
 
     def apply_im(self, image):
-        image = vertical_flip(image)
+        image = F.vertical_flip(image)
         return image
 
     def apply_mask(self, mask):
-        mask = vertical_flip(mask)
+        mask = F.vertical_flip(mask)
         return mask
 
     def apply_bbox(self, bbox, height):
@@ -894,13 +888,13 @@ class RandomVerticalFlip(Transform):
     def apply_segm(self, segms, height, width):
         flipped_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 flipped_segms.append(
-                    [vertical_flip_poly(poly, height) for poly in segm])
+                    [F.vertical_flip_poly(poly, height) for poly in segm])
             else:
                 # RLE format
-                flipped_segms.append(vertical_flip_rle(segm, height, width))
+                flipped_segms.append(F.vertical_flip_rle(segm, height, width))
         return flipped_segms
 
     def apply(self, sample):
@@ -978,7 +972,7 @@ class Normalize(Transform):
         mean = np.asarray(
             self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
         std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
-        image = normalize(image, mean, std, self.min_val, self.max_val)
+        image = F.normalize(image, mean, std, self.min_val, self.max_val)
         return image
 
     def apply(self, sample):
@@ -1007,12 +1001,12 @@ class CenterCrop(Transform):
         self.crop_size = crop_size
 
     def apply_im(self, image):
-        image = center_crop(image, self.crop_size)
+        image = F.center_crop(image, self.crop_size)
 
         return image
 
     def apply_mask(self, mask):
-        mask = center_crop(mask, self.crop_size)
+        mask = F.center_crop(mask, self.crop_size)
         return mask
 
     def apply(self, sample):
@@ -1159,12 +1153,12 @@ class RandomCrop(Transform):
             crop_segms = []
             for id in valid_ids:
                 segm = segms[id]
-                if is_poly(segm):
+                if F.is_poly(segm):
                     # Polygon format
-                    crop_segms.append(crop_poly(segm, crop))
+                    crop_segms.append(F.crop_poly(segm, crop))
                 else:
                     # RLE format
-                    crop_segms.append(crop_rle(segm, crop, height, width))
+                    crop_segms.append(F.crop_rle(segm, crop, height, width))
 
             return crop_segms
 
@@ -1196,7 +1190,7 @@ class RandomCrop(Transform):
         delete_id = list()
         valid_polys = list()
         for idx, poly in enumerate(crop_polys):
-            if not crop_poly:
+            if not poly:
                 delete_id.append(idx)
             else:
                 valid_polys.append(poly)
@@ -1231,7 +1225,7 @@ class RandomCrop(Transform):
             if 'sr_factor' in sample:
                 sample['target'] = self.apply_im(
                     sample['target'],
-                    calc_hr_shape(crop_box, sample['sr_factor']))
+                    F.calc_hr_shape(crop_box, sample['sr_factor']))
             else:
                 sample['target'] = self.apply_im(sample['image'], crop_box)
 
@@ -1393,14 +1387,14 @@ class Pad(Transform):
         h, w = size
         expanded_segms = []
         for segm in segms:
-            if is_poly(segm):
+            if F.is_poly(segm):
                 # Polygon format
                 expanded_segms.append(
-                    [expand_poly(poly, x, y) for poly in segm])
+                    [F.expand_poly(poly, x, y) for poly in segm])
             else:
                 # RLE format
                 expanded_segms.append(
-                    expand_rle(segm, x, y, height, width, h, w))
+                    F.expand_rle(segm, x, y, height, width, h, w))
         return expanded_segms
 
     def _get_offsets(self, im_h, im_w, h, w):
@@ -1450,7 +1444,7 @@ class Pad(Transform):
                 sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
         if 'target' in sample:
             if 'sr_factor' in sample:
-                hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
+                hr_shape = F.calc_hr_shape((h, w), sample['sr_factor'])
                 hr_offsets = self._get_offsets(*sample['target'].shape[:2],
                                                *hr_shape)
                 sample['target'] = self.apply_im(sample['target'], hr_offsets,
@@ -1757,7 +1751,7 @@ class Dehaze(Transform):
         self.gamma = gamma
 
     def apply_im(self, image):
-        image = dehaze(image, self.gamma)
+        image = F.dehaze(image, self.gamma)
         return image
 
    def apply(self, sample):
@@ -1819,7 +1813,7 @@ class SelectBand(Transform):
         self.apply_to_tar = apply_to_tar
 
     def apply_im(self, image):
-        image = select_bands(image, self.band_list)
+        image = F.select_bands(image, self.band_list)
         return image
 
     def apply(self, sample):
@@ -1944,10 +1938,10 @@ class RandomSwap(Transform):
 
 class ReloadMask(Transform):
     def apply(self, sample):
-        sample['mask'] = decode_seg_mask(sample['mask_ori'])
+        sample['mask'] = F.decode_seg_mask(sample['mask_ori'])
         if 'aux_masks' in sample:
             sample['aux_masks'] = list(
-                map(decode_seg_mask, sample['aux_masks_ori']))
+                map(F.decode_seg_mask, sample['aux_masks_ori']))
         return sample
 
 
@@ -1987,18 +1981,21 @@ class MatchRadiance(Transform):
 
     Args:
         method (str, optional): Method used to match the radiance of the
-            bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands
-            for histogram matching and 'lsr' stands for least-squares
-            regression. Default: 'hist'.
+            bi-temporal images. Choices are {'hist', 'lsr', 'fft'}. 'hist'
+            stands for histogram matching, 'lsr' stands for least-squares
+            regression, and 'fft' replaces the low-frequency components of the
+            input image with those of the reference image. Default: 'hist'.
""" def __init__(self, method='hist'): super(MatchRadiance, self).__init__() if method == 'hist': - self._match_func = match_histograms + self._match_func = F.match_histograms elif method == 'lsr': - self._match_func = match_by_regression + self._match_func = F.match_by_regression + elif method == 'fft': + self._match_func = F.match_lf_components else: raise ValueError( "{} is not a supported radiometric correction method.".format( diff --git a/tests/transforms/test_operators.py b/tests/transforms/test_operators.py index bb354eb..a892e49 100644 --- a/tests/transforms/test_operators.py +++ b/tests/transforms/test_operators.py @@ -390,6 +390,9 @@ class TestTransform(CpuCommonTest): test_lsr = make_test_func( T.MatchRadiance, 'lsr', _filter=_filter_only_mt) test_lsr(self) + test_fft = make_test_func( + T.MatchRadiance, 'fft', _filter=_filter_only_mt) + test_fft(self) class TestCompose(CpuCommonTest):