Merge pull request #39 from Bobholamovic/add_fft

[Feat] Enhance `MatchRadiance` and Add FFT Mode
own
cc 2 years ago committed by GitHub
commit 820e5e66d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      .github/PULL_REQUEST_TEMPLATE.md
  2. 49
      paddlers/transforms/functions.py
  3. 111
      paddlers/transforms/operators.py
  4. 3
      tests/transforms/test_operators.py

@ -1,8 +1,8 @@
### PR types
<!-- One of [ New features | Bug fixes | Function optimization | Performance optimization | Breaking changes | Others ] -->
<!-- One of [ New features | Bug fixes | Code refactoring | Performance optimization | Breaking changes | Others ] -->
### PR changes
<!-- One of [ Models | APIs | Docs | Others ] -->
<!-- One of [ Models | Transforms | Tools | Examples | Docs | Tests | Others ] -->
### Description
<!-- Describe what this PR does -->

@ -604,6 +604,55 @@ def match_by_regression(im, ref, pif_loc=None):
return matched
def match_lf_components(im, ref, lf_ratio=0.01):
"""
Match the low-frequency components of two images.
Args:
im (np.ndarray): Input image.
ref (np.ndarray): Reference image to match. `ref` must have the same shape
as `im`.
lf_ratio (float, optional): Proportion of frequence components that should
be recognized as low-frequency components in the frequency domain.
Default: 0.01.
Returns:
np.ndarray: Transformed input image.
Raises:
ValueError: When the shape of `ref` differs from that of `im`.
"""
def _replace_lf(im, ref, lf_ratio):
h, w = im.shape
h_lf, w_lf = int(h // 2 * lf_ratio), int(w // 2 * lf_ratio)
freq_im = np.fft.fft2(im)
freq_ref = np.fft.fft2(ref)
if h_lf > 0:
freq_im[:h_lf] = freq_ref[:h_lf]
freq_im[-h_lf:] = freq_ref[-h_lf:]
if w_lf > 0:
freq_im[:, :w_lf] = freq_ref[:, :w_lf]
freq_im[:, -w_lf:] = freq_ref[:, -w_lf:]
recon_im = np.fft.ifft2(freq_im)
recon_im = np.abs(recon_im)
return recon_im
if im.shape != ref.shape:
raise ValueError("Image and Reference must have the same shape!")
if im.ndim > 2:
# Multiple channels
matched = np.empty(im.shape, dtype=im.dtype)
for ch in range(im.shape[-1]):
matched[..., ch] = _replace_lf(im[..., ch], ref[..., ch], lf_ratio)
else:
# Single channel
matched = _replace_lf(im, ref, lf_ratio).astype(im.dtype)
return matched
def inv_pca(im, joblib_path):
"""
Perform inverse PCA transformation.

@ -27,15 +27,9 @@ from PIL import Image
from joblib import load
import paddlers
import paddlers.transforms.functions as F
import paddlers.transforms.indices as indices
import paddlers.transforms.satellites as satellites
from .functions import (
normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape,
match_by_regression, match_histograms)
__all__ = [
"Compose",
@ -248,7 +242,7 @@ class DecodeImg(Transform):
raise IOError('Cannot open', img_path)
im_data = dataset.ReadAsArray()
if im_data.ndim == 2 and self.decode_sar:
im_data = to_intensity(im_data)
im_data = F.to_intensity(im_data)
im_data = im_data[:, :, np.newaxis]
else:
if im_data.ndim == 3:
@ -292,7 +286,7 @@ class DecodeImg(Transform):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.to_uint8:
image = to_uint8(image, stretch=self.use_stretch)
image = F.to_uint8(image, stretch=self.use_stretch)
if self.read_geo_info:
return image, geo_info_dict
@ -447,15 +441,15 @@ class Resize(Transform):
im_scale_x, im_scale_y = scale
resized_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
resized_segms.append([
resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
F.resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
])
else:
# RLE format
resized_segms.append(
resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
F.resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
return resized_segms
@ -500,7 +494,7 @@ class Resize(Transform):
# For SR tasks
sample['target'] = self.apply_im(
sample['target'], interp,
calc_hr_shape(target_size, sample['sr_factor']))
F.calc_hr_shape(target_size, sample['sr_factor']))
else:
# For non-SR tasks
sample['target'] = self.apply_im(sample['target'], interp,
@ -697,16 +691,16 @@ class RandomFlipOrRotate(Transform):
def apply_im(self, image, mode_id, flip_mode=True):
if flip_mode:
image = img_flip(image, mode_id)
image = F.img_flip(image, mode_id)
else:
image = img_simple_rotate(image, mode_id)
image = F.img_simple_rotate(image, mode_id)
return image
def apply_mask(self, mask, mode_id, flip_mode=True):
if flip_mode:
mask = img_flip(mask, mode_id)
mask = F.img_flip(mask, mode_id)
else:
mask = img_simple_rotate(mask, mode_id)
mask = F.img_simple_rotate(mask, mode_id)
return mask
def apply_bbox(self, bbox, mode_id, flip_mode=True):
@ -822,11 +816,11 @@ class RandomHorizontalFlip(Transform):
self.prob = prob
def apply_im(self, image):
image = horizontal_flip(image)
image = F.horizontal_flip(image)
return image
def apply_mask(self, mask):
mask = horizontal_flip(mask)
mask = F.horizontal_flip(mask)
return mask
def apply_bbox(self, bbox, width):
@ -839,13 +833,13 @@ class RandomHorizontalFlip(Transform):
def apply_segm(self, segms, height, width):
flipped_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
flipped_segms.append(
[horizontal_flip_poly(poly, width) for poly in segm])
[F.horizontal_flip_poly(poly, width) for poly in segm])
else:
# RLE format
flipped_segms.append(horizontal_flip_rle(segm, height, width))
flipped_segms.append(F.horizontal_flip_rle(segm, height, width))
return flipped_segms
def apply(self, sample):
@ -882,11 +876,11 @@ class RandomVerticalFlip(Transform):
self.prob = prob
def apply_im(self, image):
image = vertical_flip(image)
image = F.vertical_flip(image)
return image
def apply_mask(self, mask):
mask = vertical_flip(mask)
mask = F.vertical_flip(mask)
return mask
def apply_bbox(self, bbox, height):
@ -899,13 +893,13 @@ class RandomVerticalFlip(Transform):
def apply_segm(self, segms, height, width):
flipped_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
flipped_segms.append(
[vertical_flip_poly(poly, height) for poly in segm])
[F.vertical_flip_poly(poly, height) for poly in segm])
else:
# RLE format
flipped_segms.append(vertical_flip_rle(segm, height, width))
flipped_segms.append(F.vertical_flip_rle(segm, height, width))
return flipped_segms
def apply(self, sample):
@ -983,7 +977,7 @@ class Normalize(Transform):
mean = np.asarray(
self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
image = normalize(image, mean, std, self.min_val, self.max_val)
image = F.normalize(image, mean, std, self.min_val, self.max_val)
return image
def apply(self, sample):
@ -1012,12 +1006,12 @@ class CenterCrop(Transform):
self.crop_size = crop_size
def apply_im(self, image):
image = center_crop(image, self.crop_size)
image = F.center_crop(image, self.crop_size)
return image
def apply_mask(self, mask):
mask = center_crop(mask, self.crop_size)
mask = F.center_crop(mask, self.crop_size)
return mask
def apply(self, sample):
@ -1164,12 +1158,12 @@ class RandomCrop(Transform):
crop_segms = []
for id in valid_ids:
segm = segms[id]
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
crop_segms.append(crop_poly(segm, crop))
crop_segms.append(F.crop_poly(segm, crop))
else:
# RLE format
crop_segms.append(crop_rle(segm, crop, height, width))
crop_segms.append(F.crop_rle(segm, crop, height, width))
return crop_segms
@ -1201,7 +1195,7 @@ class RandomCrop(Transform):
delete_id = list()
valid_polys = list()
for idx, poly in enumerate(crop_polys):
if not crop_poly:
if not poly:
delete_id.append(idx)
else:
valid_polys.append(poly)
@ -1236,7 +1230,7 @@ class RandomCrop(Transform):
if 'sr_factor' in sample:
sample['target'] = self.apply_im(
sample['target'],
calc_hr_shape(crop_box, sample['sr_factor']))
F.calc_hr_shape(crop_box, sample['sr_factor']))
else:
sample['target'] = self.apply_im(sample['image'], crop_box)
@ -1398,14 +1392,14 @@ class Pad(Transform):
h, w = size
expanded_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
expanded_segms.append(
[expand_poly(poly, x, y) for poly in segm])
[F.expand_poly(poly, x, y) for poly in segm])
else:
# RLE format
expanded_segms.append(
expand_rle(segm, x, y, height, width, h, w))
F.expand_rle(segm, x, y, height, width, h, w))
return expanded_segms
def _get_offsets(self, im_h, im_w, h, w):
@ -1455,7 +1449,7 @@ class Pad(Transform):
sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
if 'target' in sample:
if 'sr_factor' in sample:
hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
hr_shape = F.calc_hr_shape((h, w), sample['sr_factor'])
hr_offsets = self._get_offsets(*sample['target'].shape[:2],
*hr_shape)
sample['target'] = self.apply_im(sample['target'], hr_offsets,
@ -1762,7 +1756,7 @@ class Dehaze(Transform):
self.gamma = gamma
def apply_im(self, image):
image = dehaze(image, self.gamma)
image = F.dehaze(image, self.gamma)
return image
def apply(self, sample):
@ -1824,7 +1818,7 @@ class SelectBand(Transform):
self.apply_to_tar = apply_to_tar
def apply_im(self, image):
image = select_bands(image, self.band_list)
image = F.select_bands(image, self.band_list)
return image
def apply(self, sample):
@ -1917,11 +1911,11 @@ class _Permute(Transform):
super(_Permute, self).__init__()
def apply(self, sample):
sample['image'] = permute(sample['image'], False)
sample['image'] = F.permute(sample['image'], False)
if 'image2' in sample:
sample['image2'] = permute(sample['image2'], False)
sample['image2'] = F.permute(sample['image2'], False)
if 'target' in sample:
sample['target'] = permute(sample['target'], False)
sample['target'] = F.permute(sample['target'], False)
return sample
@ -1949,10 +1943,10 @@ class RandomSwap(Transform):
class ReloadMask(Transform):
def apply(self, sample):
sample['mask'] = decode_seg_mask(sample['mask_ori'])
sample['mask'] = F.decode_seg_mask(sample['mask_ori'])
if 'aux_masks' in sample:
sample['aux_masks'] = list(
map(decode_seg_mask, sample['aux_masks_ori']))
map(F.decode_seg_mask, sample['aux_masks_ori']))
return sample
@ -2006,18 +2000,21 @@ class MatchRadiance(Transform):
Args:
method (str, optional): Method used to match the radiance of the
bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands
for histogram matching and 'lsr' stands for least-squares
regression. Default: 'hist'.
bi-temporal images. Choices are {'hist', 'lsr', 'fft}. 'hist'
stands for histogram matching, 'lsr' stands for least-squares
regression, and 'fft' replaces the low-frequency components of
the image to match the reference image. Default: 'hist'.
"""
def __init__(self, method='hist'):
super(MatchRadiance, self).__init__()
if method == 'hist':
self._match_func = match_histograms
self._match_func = F.match_histograms
elif method == 'lsr':
self._match_func = match_by_regression
self._match_func = F.match_by_regression
elif method == 'fft':
self._match_func = F.match_lf_components
else:
raise ValueError(
"{} is not a supported radiometric correction method.".format(
@ -2049,7 +2046,7 @@ class ArrangeSegmenter(Arrange):
mask = sample['mask']
mask = mask.astype('int64')
image = permute(sample['image'], False)
image = F.permute(sample['image'], False)
if self.mode == 'train':
return image, mask
if self.mode == 'eval':
@ -2064,8 +2061,8 @@ class ArrangeChangeDetector(Arrange):
mask = sample['mask']
mask = mask.astype('int64')
image_t1 = permute(sample['image'], False)
image_t2 = permute(sample['image2'], False)
image_t1 = F.permute(sample['image'], False)
image_t2 = F.permute(sample['image2'], False)
if self.mode == 'train':
masks = [mask]
if 'aux_masks' in sample:
@ -2082,7 +2079,7 @@ class ArrangeChangeDetector(Arrange):
class ArrangeClassifier(Arrange):
def apply(self, sample):
image = permute(sample['image'], False)
image = F.permute(sample['image'], False)
if self.mode in ['train', 'eval']:
return image, sample['label']
else:
@ -2099,8 +2096,8 @@ class ArrangeDetector(Arrange):
class ArrangeRestorer(Arrange):
def apply(self, sample):
if 'target' in sample:
target = permute(sample['target'], False)
image = permute(sample['image'], False)
target = F.permute(sample['target'], False)
image = F.permute(sample['image'], False)
if self.mode == 'train':
return image, target
if self.mode == 'eval':

@ -400,6 +400,9 @@ class TestTransform(CpuCommonTest):
test_lsr = make_test_func(
T.MatchRadiance, 'lsr', _filter=_filter_only_mt)
test_lsr(self)
test_fft = make_test_func(
T.MatchRadiance, 'fft', _filter=_filter_only_mt)
test_fft(self)
class TestCompose(CpuCommonTest):

Loading…
Cancel
Save