Add match_lf_components code

own
Bobholamovic 2 years ago
parent 92933a8d4a
commit bd7c671e21
  1. 49
      paddlers/transforms/functions.py
  2. 93
      paddlers/transforms/operators.py
  3. 3
      tests/transforms/test_operators.py

@ -604,6 +604,55 @@ def match_by_regression(im, ref, pif_loc=None):
return matched
def match_lf_components(im, ref, lf_ratio=0.01):
    """
    Match the low-frequency components of two images.

    The 2-D Fourier spectrum of each channel of `im` is computed, the rows and
    columns of that spectrum closest to the zero frequency are overwritten
    with the corresponding coefficients of `ref`, and the result is
    transformed back to the spatial domain.

    Args:
        im (np.ndarray): Input image.
        ref (np.ndarray): Reference image to match. `ref` must have the same shape
            as `im`.
        lf_ratio (float, optional): Proportion of frequency components that should
            be recognized as low-frequency components in the frequency domain.
            Default: 0.01.

    Returns:
        np.ndarray: Transformed input image, with the same shape and dtype as
            `im`. For integer dtypes the result is rounded and saturated to
            the representable range of the dtype.

    Raises:
        ValueError: When the shape of `ref` differs from that of `im`.
    """

    def _replace_lf(im, ref, lf_ratio):
        # Operates on a single 2-D channel and returns a float array.
        h, w = im.shape
        # Number of low-frequency rows/columns taken on each side of the
        # spectrum.
        h_lf, w_lf = int(h // 2 * lf_ratio), int(w // 2 * lf_ratio)
        freq_im = np.fft.fft2(im)
        freq_ref = np.fft.fft2(ref)
        # The spectrum is not shifted, so the lowest frequencies sit at both
        # ends of each axis (non-negative ones first, negative ones last).
        # The `> 0` guards matter: with a zero count, `[-0:]` would select
        # the whole axis.
        if h_lf > 0:
            freq_im[:h_lf] = freq_ref[:h_lf]
            freq_im[-h_lf:] = freq_ref[-h_lf:]
        if w_lf > 0:
            freq_im[:, :w_lf] = freq_ref[:, :w_lf]
            freq_im[:, -w_lf:] = freq_ref[:, -w_lf:]
        recon_im = np.fft.ifft2(freq_im)
        # The inverse transform is complex-valued; keep its magnitude.
        recon_im = np.abs(recon_im)
        return recon_im

    def _cast(arr, dtype):
        # A raw float-to-integer cast truncates FFT round-off (199.999... ->
        # 199) and wraps around for out-of-range values, so round and
        # saturate first. Non-integer dtypes are cast unchanged.
        if np.issubdtype(dtype, np.integer):
            info = np.iinfo(dtype)
            arr = np.clip(np.rint(arr), info.min, info.max)
        return arr.astype(dtype)

    if im.shape != ref.shape:
        raise ValueError("Image and Reference must have the same shape!")

    if im.ndim > 2:
        # Multiple channels: process each slice along the last axis on its own.
        matched = np.empty(im.shape, dtype=im.dtype)
        for ch in range(im.shape[-1]):
            matched[..., ch] = _cast(
                _replace_lf(im[..., ch], ref[..., ch], lf_ratio), im.dtype)
    else:
        # Single channel
        matched = _cast(_replace_lf(im, ref, lf_ratio), im.dtype)

    return matched
def inv_pca(im, joblib_path):
"""
Perform inverse PCA transformation.

@ -27,14 +27,8 @@ from PIL import Image
from joblib import load
import paddlers
import paddlers.transforms.functions as F
import paddlers.transforms.indices as indices
from .functions import (
normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly,
horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly,
vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle,
resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8,
img_flip, img_simple_rotate, decode_seg_mask, calc_hr_shape,
match_by_regression, match_histograms)
__all__ = [
"Compose",
@ -243,7 +237,7 @@ class DecodeImg(Transform):
raise IOError('Cannot open', img_path)
im_data = dataset.ReadAsArray()
if im_data.ndim == 2 and self.decode_sar:
im_data = to_intensity(im_data)
im_data = F.to_intensity(im_data)
im_data = im_data[:, :, np.newaxis]
else:
if im_data.ndim == 3:
@ -287,7 +281,7 @@ class DecodeImg(Transform):
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.to_uint8:
image = to_uint8(image)
image = F.to_uint8(image)
if self.read_geo_info:
return image, geo_info_dict
@ -442,15 +436,15 @@ class Resize(Transform):
im_scale_x, im_scale_y = scale
resized_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
resized_segms.append([
resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
F.resize_poly(poly, im_scale_x, im_scale_y) for poly in segm
])
else:
# RLE format
resized_segms.append(
resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
F.resize_rle(segm, im_h, im_w, im_scale_x, im_scale_y))
return resized_segms
@ -495,7 +489,7 @@ class Resize(Transform):
# For SR tasks
sample['target'] = self.apply_im(
sample['target'], interp,
calc_hr_shape(target_size, sample['sr_factor']))
F.calc_hr_shape(target_size, sample['sr_factor']))
else:
# For non-SR tasks
sample['target'] = self.apply_im(sample['target'], interp,
@ -692,16 +686,16 @@ class RandomFlipOrRotate(Transform):
def apply_im(self, image, mode_id, flip_mode=True):
if flip_mode:
image = img_flip(image, mode_id)
image = F.img_flip(image, mode_id)
else:
image = img_simple_rotate(image, mode_id)
image = F.img_simple_rotate(image, mode_id)
return image
def apply_mask(self, mask, mode_id, flip_mode=True):
if flip_mode:
mask = img_flip(mask, mode_id)
mask = F.img_flip(mask, mode_id)
else:
mask = img_simple_rotate(mask, mode_id)
mask = F.img_simple_rotate(mask, mode_id)
return mask
def apply_bbox(self, bbox, mode_id, flip_mode=True):
@ -817,11 +811,11 @@ class RandomHorizontalFlip(Transform):
self.prob = prob
def apply_im(self, image):
image = horizontal_flip(image)
image = F.horizontal_flip(image)
return image
def apply_mask(self, mask):
mask = horizontal_flip(mask)
mask = F.horizontal_flip(mask)
return mask
def apply_bbox(self, bbox, width):
@ -834,13 +828,13 @@ class RandomHorizontalFlip(Transform):
def apply_segm(self, segms, height, width):
flipped_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
flipped_segms.append(
[horizontal_flip_poly(poly, width) for poly in segm])
[F.horizontal_flip_poly(poly, width) for poly in segm])
else:
# RLE format
flipped_segms.append(horizontal_flip_rle(segm, height, width))
flipped_segms.append(F.horizontal_flip_rle(segm, height, width))
return flipped_segms
def apply(self, sample):
@ -877,11 +871,11 @@ class RandomVerticalFlip(Transform):
self.prob = prob
def apply_im(self, image):
image = vertical_flip(image)
image = F.vertical_flip(image)
return image
def apply_mask(self, mask):
mask = vertical_flip(mask)
mask = F.vertical_flip(mask)
return mask
def apply_bbox(self, bbox, height):
@ -894,13 +888,13 @@ class RandomVerticalFlip(Transform):
def apply_segm(self, segms, height, width):
flipped_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
flipped_segms.append(
[vertical_flip_poly(poly, height) for poly in segm])
[F.vertical_flip_poly(poly, height) for poly in segm])
else:
# RLE format
flipped_segms.append(vertical_flip_rle(segm, height, width))
flipped_segms.append(F.vertical_flip_rle(segm, height, width))
return flipped_segms
def apply(self, sample):
@ -978,7 +972,7 @@ class Normalize(Transform):
mean = np.asarray(
self.mean, dtype=np.float32)[np.newaxis, np.newaxis, :]
std = np.asarray(self.std, dtype=np.float32)[np.newaxis, np.newaxis, :]
image = normalize(image, mean, std, self.min_val, self.max_val)
image = F.normalize(image, mean, std, self.min_val, self.max_val)
return image
def apply(self, sample):
@ -1007,12 +1001,12 @@ class CenterCrop(Transform):
self.crop_size = crop_size
def apply_im(self, image):
image = center_crop(image, self.crop_size)
image = F.center_crop(image, self.crop_size)
return image
def apply_mask(self, mask):
mask = center_crop(mask, self.crop_size)
mask = F.center_crop(mask, self.crop_size)
return mask
def apply(self, sample):
@ -1159,12 +1153,12 @@ class RandomCrop(Transform):
crop_segms = []
for id in valid_ids:
segm = segms[id]
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
crop_segms.append(crop_poly(segm, crop))
crop_segms.append(F.crop_poly(segm, crop))
else:
# RLE format
crop_segms.append(crop_rle(segm, crop, height, width))
crop_segms.append(F.crop_rle(segm, crop, height, width))
return crop_segms
@ -1196,7 +1190,7 @@ class RandomCrop(Transform):
delete_id = list()
valid_polys = list()
for idx, poly in enumerate(crop_polys):
if not crop_poly:
if not poly:
delete_id.append(idx)
else:
valid_polys.append(poly)
@ -1231,7 +1225,7 @@ class RandomCrop(Transform):
if 'sr_factor' in sample:
sample['target'] = self.apply_im(
sample['target'],
calc_hr_shape(crop_box, sample['sr_factor']))
F.calc_hr_shape(crop_box, sample['sr_factor']))
else:
sample['target'] = self.apply_im(sample['image'], crop_box)
@ -1393,14 +1387,14 @@ class Pad(Transform):
h, w = size
expanded_segms = []
for segm in segms:
if is_poly(segm):
if F.is_poly(segm):
# Polygon format
expanded_segms.append(
[expand_poly(poly, x, y) for poly in segm])
[F.expand_poly(poly, x, y) for poly in segm])
else:
# RLE format
expanded_segms.append(
expand_rle(segm, x, y, height, width, h, w))
F.expand_rle(segm, x, y, height, width, h, w))
return expanded_segms
def _get_offsets(self, im_h, im_w, h, w):
@ -1450,7 +1444,7 @@ class Pad(Transform):
sample['gt_poly'], offsets, im_size=[im_h, im_w], size=[h, w])
if 'target' in sample:
if 'sr_factor' in sample:
hr_shape = calc_hr_shape((h, w), sample['sr_factor'])
hr_shape = F.calc_hr_shape((h, w), sample['sr_factor'])
hr_offsets = self._get_offsets(*sample['target'].shape[:2],
*hr_shape)
sample['target'] = self.apply_im(sample['target'], hr_offsets,
@ -1757,7 +1751,7 @@ class Dehaze(Transform):
self.gamma = gamma
def apply_im(self, image):
image = dehaze(image, self.gamma)
image = F.dehaze(image, self.gamma)
return image
def apply(self, sample):
@ -1819,7 +1813,7 @@ class SelectBand(Transform):
self.apply_to_tar = apply_to_tar
def apply_im(self, image):
image = select_bands(image, self.band_list)
image = F.select_bands(image, self.band_list)
return image
def apply(self, sample):
@ -1944,10 +1938,10 @@ class RandomSwap(Transform):
class ReloadMask(Transform):
def apply(self, sample):
sample['mask'] = decode_seg_mask(sample['mask_ori'])
sample['mask'] = F.decode_seg_mask(sample['mask_ori'])
if 'aux_masks' in sample:
sample['aux_masks'] = list(
map(decode_seg_mask, sample['aux_masks_ori']))
map(F.decode_seg_mask, sample['aux_masks_ori']))
return sample
@ -1987,18 +1981,21 @@ class MatchRadiance(Transform):
Args:
method (str, optional): Method used to match the radiance of the
bi-temporal images. Choices are {'hist', 'lsr'}. 'hist' stands
for histogram matching and 'lsr' stands for least-squares
regression. Default: 'hist'.
bi-temporal images. Choices are {'hist', 'lsr', 'fft'}. 'hist'
stands for histogram matching, 'lsr' stands for least-squares
regression, and 'fft' replaces the low-frequency components of
the image to match the reference image. Default: 'hist'.
"""
def __init__(self, method='hist'):
super(MatchRadiance, self).__init__()
if method == 'hist':
self._match_func = match_histograms
self._match_func = F.match_histograms
elif method == 'lsr':
self._match_func = match_by_regression
self._match_func = F.match_by_regression
elif method == 'fft':
self._match_func = F.match_lf_components
else:
raise ValueError(
"{} is not a supported radiometric correction method.".format(

@ -390,6 +390,9 @@ class TestTransform(CpuCommonTest):
test_lsr = make_test_func(
T.MatchRadiance, 'lsr', _filter=_filter_only_mt)
test_lsr(self)
test_fft = make_test_func(
T.MatchRadiance, 'fft', _filter=_filter_only_mt)
test_fft(self)
class TestCompose(CpuCommonTest):

Loading…
Cancel
Save