[Feature] Update transform (#28)

* [Feature] Update transform

* [Fix][Transform] fix len of band_list is zero
main
Yizhou Chen 3 years ago committed by GitHub
parent 35f51f1ae8
commit 58cacc28d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 39
      paddlers/transforms/functions.py
  2. 155
      paddlers/transforms/operators.py

@ -206,6 +206,7 @@ def to_uint8(im):
Returns:
np.ndarray: Image converted to uint8.
"""
# 2% linear stretch
def _two_percentLinear(image, max_out=255, min_out=0):
def _gray_process(gray, maxout=max_out, minout=min_out):
@ -216,6 +217,7 @@ def to_uint8(im):
processed_gray = ((truncated_gray - low_value) / (high_value - low_value)) * \
(maxout - minout)
return processed_gray
if len(image.shape) == 3:
processes = []
for b in range(image.shape[-1]):
@ -244,7 +246,7 @@ def to_uint8(im):
lut = []
for bt in range(0, len(hist), NUMS):
# step size
step = reduce(operator.add, hist[bt : bt + NUMS]) / (NUMS - 1)
step = reduce(operator.add, hist[bt:bt + NUMS]) / (NUMS - 1)
# create balanced lookup table
n = 0
for i in range(NUMS):
@ -301,14 +303,18 @@ def select_bands(im, band_list=[1, 2, 3]):
Returns:
np.ndarray: The image after band selection.
"""
if len(im.shape) == 2: # just have one channel
return im
if not isinstance(band_list, list) or len(band_list) == 0:
raise TypeError("band_list must be non empty list.")
total_band = im.shape[-1]
result = []
for band in band_list:
band = int(band - 1)
if band < 0 or band >= total_band:
raise ValueError(
"The element in band_list must > 1 and <= {}.".format(str(total_band)))
result.append()
raise ValueError("The element in band_list must > 1 and <= {}.".
format(str(total_band)))
result.append(im[:, :, band])
ima = np.stack(result, axis=0)
return ima
@ -323,6 +329,7 @@ def de_haze(im, gamma=False):
Returns:
np.ndarray: The image after defogging.
"""
def _guided_filter(I, p, r, eps):
m_I = cv2.boxFilter(I, -1, (r, r))
m_p = cv2.boxFilter(p, -1, (r, r))
@ -350,16 +357,17 @@ def de_haze(im, gamma=False):
atmo_illum = np.mean(im, 2)[atmo_mask >= ht[1][lmax]].max()
atmo_mask = np.minimum(atmo_mask * w, maxatmo_mask)
return atmo_mask, atmo_illum
if np.max(im) > 1:
im = im / 255.
result = np.zeros(im.shape)
mask_img, atmo_illum = _de_fog(im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
mask_img, atmo_illum = _de_fog(
im, r=81, w=0.95, maxatmo_mask=0.80, eps=1e-8)
for k in range(3):
result[:, :, k] = (im[:, :, k] - mask_img) / (1 - mask_img / atmo_illum)
result = np.clip(result, 0, 1)
if gamma:
result = result ** (np.log(0.5) / np.log(result.mean()))
result = result**(np.log(0.5) / np.log(result.mean()))
return (result * 255).astype("uint8")
@ -398,7 +406,8 @@ def match_histograms(im, ref):
ValueError: When the number of channels of `ref` differs from that of `im`.
"""
# TODO: Check the data types of the inputs to see if they are supported by skimage
return exposure.match_histograms(im, ref, channel_axis=-1 if im.ndim>2 else None)
return exposure.match_histograms(
im, ref, channel_axis=-1 if im.ndim > 2 else None)
def match_by_regression(im, ref, pif_loc=None):
@ -418,27 +427,29 @@ def match_by_regression(im, ref, pif_loc=None):
Raises:
ValueError: When the shape of `ref` differs from that of `im`.
"""
def _linear_regress(im, ref, loc):
regressor = LinearRegression()
if loc is not None:
x, y = im[loc], ref[loc]
else:
x, y = im, ref
x, y = x.reshape(-1,1), y.ravel()
x, y = x.reshape(-1, 1), y.ravel()
regressor.fit(x, y)
matched = regressor.predict(im.reshape(-1,1))
matched = regressor.predict(im.reshape(-1, 1))
return matched.reshape(im.shape)
if im.shape != ref.shape:
raise ValueError("Image and Reference must have the same shape!")
raise ValueError("Image and Reference must have the same shape!")
if im.ndim > 2:
# Multiple channels
matched = np.empty(im.shape, dtype=im.dtype)
for ch in range(im.shape[-1]):
matched[..., ch] = _linear_regress(im[..., ch], ref[..., ch], pif_loc)
matched[..., ch] = _linear_regress(im[..., ch], ref[..., ch],
pif_loc)
else:
# Single channel
matched = _linear_regress(im, ref, pif_loc).astype(im.dtype)
return matched
return matched

@ -31,17 +31,16 @@ import paddlers
from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \
horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \
crop_rle, expand_poly, expand_rle, resize_poly, resize_rle
crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, de_haze, pca, select_bands, \
to_intensity, to_uint8
__all__ = [
"Compose", "ImgDecoder", "Resize", "RandomResize", "ResizeByShort",
"RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip",
"RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop",
"RandomScaleAspect", "RandomExpand", "Padding", "MixupImage",
"RandomDistort", "RandomBlur",
"RandomSwap",
"ArrangeSegmenter", "ArrangeChangeDetector",
"RandomDistort", "RandomBlur", "RandomSwap", "Defogging", "DimReducing",
"BandSelecting", "ArrangeSegmenter", "ArrangeChangeDetector",
"ArrangeClassifier", "ArrangeDetector"
]
@ -85,7 +84,8 @@ class Transform(object):
if 'gt_bbox' in sample:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'])
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
sample['aux_masks'] = list(
map(self.apply_mask, sample['aux_masks']))
return sample
@ -105,9 +105,10 @@ class ImgDecoder(Transform):
to_rgb (bool, optional): If True, convert input images from BGR format to RGB format. Defaults to True.
"""
def __init__(self, to_rgb=True):
def __init__(self, to_rgb=True, to_uint8=True):
super(ImgDecoder, self).__init__()
self.to_rgb = to_rgb
self.to_uint8 = to_uint8
def read_img(self, img_path, input_channel=3):
img_format = imghdr.what(img_path)
@ -129,6 +130,7 @@ class ImgDecoder(Transform):
raise Exception('Can not open', img_path)
im_data = dataset.ReadAsArray()
if im_data.ndim == 2:
im_data = to_intensity(im_data) # is read SAR
im_data = im_data[:, :, np.newaxis]
elif im_data.ndim == 3:
im_data = im_data.transpose((1, 2, 0))
@ -158,6 +160,9 @@ class ImgDecoder(Transform):
if self.to_rgb and image.shape[-1] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.to_uint8:
image = to_uint8(image)
return image
def apply_mask(self, mask):
@ -191,7 +196,8 @@ class ImgDecoder(Transform):
raise Exception(
"The height or width of the im is not same as the mask")
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
sample['aux_masks'] = list(
map(self.apply_mask, sample['aux_masks']))
# TODO: check the shape of auxiliary masks
sample['im_shape'] = np.array(
@ -350,12 +356,16 @@ class Resize(Transform):
sample['image'] = self.apply_im(sample['image'], interp, target_size)
if 'image2' in sample:
sample['image2'] = self.apply_im(sample['image2'], interp, target_size)
sample['image2'] = self.apply_im(sample['image2'], interp,
target_size)
if 'mask' in sample:
sample['mask'] = self.apply_mask(sample['mask'], target_size)
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(partial(self.apply_mask, target_size=target_size), sample['aux_masks']))
sample['aux_masks'] = list(
map(partial(
self.apply_mask, target_size=target_size),
sample['aux_masks']))
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(
sample['gt_bbox'], [im_scale_x, im_scale_y], target_size)
@ -557,7 +567,8 @@ class RandomHorizontalFlip(Transform):
if 'mask' in sample:
sample['mask'] = self.apply_mask(sample['mask'])
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
sample['aux_masks'] = list(
map(self.apply_mask, sample['aux_masks']))
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_w)
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@ -614,7 +625,8 @@ class RandomVerticalFlip(Transform):
if 'mask' in sample:
sample['mask'] = self.apply_mask(sample['mask'])
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
sample['aux_masks'] = list(
map(self.apply_mask, sample['aux_masks']))
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], im_h)
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@ -653,8 +665,8 @@ class Normalize(Transform):
from functools import reduce
if reduce(lambda x, y: x * y, std) == 0:
raise ValueError(
'Std should not have 0, but received is {}'.format(std))
raise ValueError('Std should not have 0, but received is {}'.format(
std))
if is_scale:
if reduce(lambda x, y: x * y,
[a - b for a, b in zip(max_val, min_val)]) == 0:
@ -679,7 +691,7 @@ class Normalize(Transform):
def apply(self, sample):
sample['image'] = self.apply_im(sample['image'])
if 'image2' in sample:
sample['image2'] = self.apply_im(sample['image2'])
sample['image2'] = self.apply_im(sample['image2'])
return sample
@ -710,11 +722,12 @@ class CenterCrop(Transform):
def apply(self, sample):
sample['image'] = self.apply_im(sample['image'])
if 'image2' in sample:
sample['image2'] = self.apply_im(sample['image2'])
sample['image2'] = self.apply_im(sample['image2'])
if 'mask' in sample:
sample['mask'] = self.apply_mask(sample['mask'])
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(self.apply_mask, sample['aux_masks']))
sample['aux_masks'] = list(
map(self.apply_mask, sample['aux_masks']))
return sample
@ -779,8 +792,7 @@ class RandomCrop(Transform):
if self.cover_all_box and iou.min() < thresh:
continue
cropped_box, valid_ids = self._crop_box_with_center_constraint(
sample['gt_bbox'],
np.array(
sample['gt_bbox'], np.array(
crop_box, dtype=np.float32))
if valid_ids.size > 0:
return crop_box, cropped_box, valid_ids
@ -907,7 +919,10 @@ class RandomCrop(Transform):
sample['mask'] = self.apply_mask(sample['mask'], crop_box)
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(partial(self.apply_mask, crop=crop_box), sample['aux_masks']))
sample['aux_masks'] = list(
map(partial(
self.apply_mask, crop=crop_box),
sample['aux_masks']))
if self.crop_size is not None:
sample = Resize(self.crop_size)(sample)
@ -1095,11 +1110,14 @@ class Padding(Transform):
sample['image'] = self.apply_im(sample['image'], offsets, (h, w))
if 'image2' in sample:
sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
sample['image2'] = self.apply_im(sample['image2'], offsets, (h, w))
if 'mask' in sample:
sample['mask'] = self.apply_mask(sample['mask'], offsets, (h, w))
if 'aux_masks' in sample:
sample['aux_masks'] = list(map(partial(self.apply_mask, offsets=offsets, target_size=(h,w)), sample['aux_masks']))
sample['aux_masks'] = list(
map(partial(
self.apply_mask, offsets=offsets, target_size=(h, w)),
sample['aux_masks']))
if 'gt_bbox' in sample and len(sample['gt_bbox']) > 0:
sample['gt_bbox'] = self.apply_bbox(sample['gt_bbox'], offsets)
if 'gt_poly' in sample and len(sample['gt_poly']) > 0:
@ -1251,7 +1269,7 @@ class RandomDistort(Transform):
res_list = []
channel = image.shape[2]
for i in range(channel // 3):
sub_img = image[:, :, 3*i : 3*(i+1)]
sub_img = image[:, :, 3 * i:3 * (i + 1)]
sub_img = sub_img.astype(np.float32)
sub_img = np.dot(image, t)
res_list.append(sub_img)
@ -1271,10 +1289,11 @@ class RandomDistort(Transform):
res_list = []
channel = image.shape[2]
for i in range(channel // 3):
sub_img = image[:, :, 3*i : 3*(i+1)]
sub_img = image[:, :, 3 * i:3 * (i + 1)]
sub_img = sub_img.astype(np.float32)
# it works, but result differ from HSV version
gray = sub_img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
gray = sub_img * np.array(
[[[0.299, 0.587, 0.114]]], dtype=np.float32)
gray = gray.sum(axis=2, keepdims=True)
gray *= (1.0 - delta)
sub_img *= delta
@ -1340,7 +1359,8 @@ class RandomDistort(Transform):
if np.random.randint(0, 2):
sample['image'] = sample['image'][..., np.random.permutation(3)]
if 'image2' in sample:
sample['image2'] = sample['image2'][..., np.random.permutation(3)]
sample['image2'] = sample['image2'][
..., np.random.permutation(3)]
return sample
@ -1380,6 +1400,77 @@ class RandomBlur(Transform):
return sample
class Defogging(Transform):
    """
    Remove fog/haze from the input image(s).

    The processing itself is delegated to `de_haze`; this class only plugs it
    into the transform pipeline for `sample['image']` and, when present,
    `sample['image2']`.

    Args:
        gamma (bool, optional): Whether to use gamma correction. Defaults to False.
    """

    def __init__(self, gamma=False):
        super(Defogging, self).__init__()
        self.gamma = gamma

    def apply_im(self, image):
        # Forward the flag chosen at construction time straight to de_haze.
        return de_haze(image, self.gamma)

    def apply(self, sample):
        # 'image' is mandatory; a missing key raises KeyError as before.
        sample['image'] = self.apply_im(sample['image'])
        # The second image is optional, so leave early when it is absent.
        if 'image2' not in sample:
            return sample
        sample['image2'] = self.apply_im(sample['image2'])
        return sample
class DimReducing(Transform):
    """
    Use PCA to reduce the dimension of the input image(s).

    Args:
        dim (int, optional): Number of dimensions to keep. Defaults to 3.
        whiten (bool, optional): Whether to apply PCA whitening. Defaults to True.
    """

    def __init__(self, dim=3, whiten=True):
        super(DimReducing, self).__init__()
        self.dim = dim
        self.whiten = whiten

    def apply_im(self, image):
        # Pass the settings stored by __init__. The previous code referenced
        # `self.gamma`, which this class never defines, so every call raised
        # AttributeError and `dim`/`whiten` were never used.
        # NOTE(review): assumes the signature is pca(im, dim, whiten), mirroring
        # the constructor arguments -- confirm against functions.pca.
        image = pca(image, self.dim, self.whiten)
        return image

    def apply(self, sample):
        # 'image' is required; 'image2' is processed only when present.
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        return sample
class BandSelecting(Transform):
    """
    Select bands of the input image(s).

    Args:
        band_list (list, optional): Bands to select, as 1-based indices.
            Defaults to [1, 2, 3].
    """

    def __init__(self, band_list=None):
        super(BandSelecting, self).__init__()
        # Build a fresh default per instance. A literal list in the signature
        # would be shared by every default-constructed instance, so mutating
        # one instance's `band_list` would silently change all the others.
        self.band_list = [1, 2, 3] if band_list is None else band_list

    def apply_im(self, image):
        # Validation of `band_list` (type, emptiness, range) is left to
        # select_bands.
        image = select_bands(image, self.band_list)
        return image

    def apply(self, sample):
        # 'image' is required; 'image2' is processed only when present.
        sample['image'] = self.apply_im(sample['image'])
        if 'image2' in sample:
            sample['image2'] = self.apply_im(sample['image2'])
        return sample
class _PadBox(Transform):
def __init__(self, num_max_boxes=50):
"""
@ -1464,7 +1555,7 @@ class _Permute(Transform):
if 'image2' in sample:
sample['image2'] = permute(sample['image2'], False)
return sample
class RandomSwap(Transform):
"""
@ -1482,7 +1573,8 @@ class RandomSwap(Transform):
if 'image2' not in sample:
raise ValueError('image2 is not found in the sample.')
if random.random() < self.prob:
sample['image'], sample['image2'] = sample['image2'], sample['image']
sample['image'], sample['image2'] = sample['image2'], sample[
'image']
return sample
@ -1530,8 +1622,11 @@ class ArrangeChangeDetector(Transform):
mask = mask.astype('int64')
masks = [mask]
if 'aux_masks' in sample:
masks.extend(map(methodcaller('astype', 'int64'), sample['aux_masks']))
return (image_t1, image_t2,) + tuple(masks)
masks.extend(
map(methodcaller('astype', 'int64'), sample['aux_masks']))
return (
image_t1,
image_t2, ) + tuple(masks)
if self.mode == 'eval':
mask = np.asarray(Image.open(mask))
mask = mask[np.newaxis, :, :].astype('int64')

Loading…
Cancel
Save