|
|
|
@ -30,39 +30,21 @@ from PIL import Image |
|
|
|
|
from joblib import load |
|
|
|
|
|
|
|
|
|
import paddlers |
|
|
|
|
from .functions import normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, \ |
|
|
|
|
horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, vertical_flip_rle, crop_poly, \ |
|
|
|
|
crop_rle, expand_poly, expand_rle, resize_poly, resize_rle, dehaze, select_bands, \ |
|
|
|
|
to_intensity, to_uint8, img_flip, img_simple_rotate |
|
|
|
|
from .functions import ( |
|
|
|
|
normalize, horizontal_flip, permute, vertical_flip, center_crop, is_poly, |
|
|
|
|
horizontal_flip_poly, horizontal_flip_rle, vertical_flip_poly, |
|
|
|
|
vertical_flip_rle, crop_poly, crop_rle, expand_poly, expand_rle, |
|
|
|
|
resize_poly, resize_rle, dehaze, select_bands, to_intensity, to_uint8, |
|
|
|
|
img_flip, img_simple_rotate, decode_seg_mask) |
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
|
|
"Compose", |
|
|
|
|
"DecodeImg", |
|
|
|
|
"Resize", |
|
|
|
|
"RandomResize", |
|
|
|
|
"ResizeByShort", |
|
|
|
|
"RandomResizeByShort", |
|
|
|
|
"ResizeByLong", |
|
|
|
|
"RandomHorizontalFlip", |
|
|
|
|
"RandomVerticalFlip", |
|
|
|
|
"Normalize", |
|
|
|
|
"CenterCrop", |
|
|
|
|
"RandomCrop", |
|
|
|
|
"RandomScaleAspect", |
|
|
|
|
"RandomExpand", |
|
|
|
|
"Pad", |
|
|
|
|
"MixupImage", |
|
|
|
|
"RandomDistort", |
|
|
|
|
"RandomBlur", |
|
|
|
|
"RandomSwap", |
|
|
|
|
"Dehaze", |
|
|
|
|
"ReduceDim", |
|
|
|
|
"SelectBand", |
|
|
|
|
"ArrangeSegmenter", |
|
|
|
|
"ArrangeChangeDetector", |
|
|
|
|
"ArrangeClassifier", |
|
|
|
|
"ArrangeDetector", |
|
|
|
|
"RandomFlipOrRotate", |
|
|
|
|
"Compose", "DecodeImg", "Resize", "RandomResize", "ResizeByShort", |
|
|
|
|
"RandomResizeByShort", "ResizeByLong", "RandomHorizontalFlip", |
|
|
|
|
"RandomVerticalFlip", "Normalize", "CenterCrop", "RandomCrop", |
|
|
|
|
"RandomScaleAspect", "RandomExpand", "Pad", "MixupImage", "RandomDistort", |
|
|
|
|
"RandomBlur", "RandomSwap", "Dehaze", "ReduceDim", "SelectBand", |
|
|
|
|
"ArrangeSegmenter", "ArrangeChangeDetector", "ArrangeClassifier", |
|
|
|
|
"ArrangeDetector", "RandomFlipOrRotate", "ReloadMask" |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
interp_dict = { |
|
|
|
@ -74,6 +56,71 @@ interp_dict = { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Compose(object): |
|
|
|
|
""" |
|
|
|
|
Apply a series of data augmentation strategies to the input. |
|
|
|
|
All input images should be in Height-Width-Channel ([H, W, C]) format. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
transforms (list[paddlers.transforms.Transform]): List of data preprocess or |
|
|
|
|
augmentation operators. |
|
|
|
|
|
|
|
|
|
Raises: |
|
|
|
|
TypeError: Invalid type of transforms. |
|
|
|
|
ValueError: Invalid length of transforms. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, transforms): |
|
|
|
|
super(Compose, self).__init__() |
|
|
|
|
if not isinstance(transforms, list): |
|
|
|
|
raise TypeError( |
|
|
|
|
"Type of transforms is invalid. Must be a list, but received is {}." |
|
|
|
|
.format(type(transforms))) |
|
|
|
|
if len(transforms) < 1: |
|
|
|
|
raise ValueError( |
|
|
|
|
"Length of transforms must not be less than 1, but received is {}." |
|
|
|
|
.format(len(transforms))) |
|
|
|
|
transforms = copy.deepcopy(transforms) |
|
|
|
|
self.arrange = self._pick_arrange(transforms) |
|
|
|
|
self.transforms = transforms |
|
|
|
|
|
|
|
|
|
def __call__(self, sample): |
|
|
|
|
""" |
|
|
|
|
This is equivalent to sequentially calling compose_obj.apply_transforms() |
|
|
|
|
and compose_obj.arrange_outputs(). |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
sample = self.apply_transforms(sample) |
|
|
|
|
sample = self.arrange_outputs(sample) |
|
|
|
|
return sample |
|
|
|
|
|
|
|
|
|
def apply_transforms(self, sample): |
|
|
|
|
for op in self.transforms: |
|
|
|
|
# Skip batch transforms amd mixup |
|
|
|
|
if isinstance(op, (paddlers.transforms.BatchRandomResize, |
|
|
|
|
paddlers.transforms.BatchRandomResizeByShort, |
|
|
|
|
MixupImage)): |
|
|
|
|
continue |
|
|
|
|
sample = op(sample) |
|
|
|
|
return sample |
|
|
|
|
|
|
|
|
|
def arrange_outputs(self, sample): |
|
|
|
|
if self.arrange is not None: |
|
|
|
|
sample = self.arrange(sample) |
|
|
|
|
return sample |
|
|
|
|
|
|
|
|
|
def _pick_arrange(self, transforms): |
|
|
|
|
arrange = None |
|
|
|
|
for idx, op in enumerate(transforms): |
|
|
|
|
if isinstance(op, Arrange): |
|
|
|
|
if idx != len(transforms) - 1: |
|
|
|
|
raise ValueError( |
|
|
|
|
"Arrange operator must be placed at the end of the list." |
|
|
|
|
) |
|
|
|
|
arrange = transforms.pop(idx) |
|
|
|
|
return arrange |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Transform(object): |
|
|
|
|
""" |
|
|
|
|
Parent class of all data augmentation operations |
|
|
|
@ -178,14 +225,14 @@ class DecodeImg(Transform): |
|
|
|
|
elif ext == '.npy': |
|
|
|
|
return np.load(img_path) |
|
|
|
|
else: |
|
|
|
|
raise TypeError('Image format {} is not supported!'.format(ext)) |
|
|
|
|
raise TypeError("Image format {} is not supported!".format(ext)) |
|
|
|
|
|
|
|
|
|
def apply_im(self, im_path): |
|
|
|
|
if isinstance(im_path, str): |
|
|
|
|
try: |
|
|
|
|
image = self.read_img(im_path) |
|
|
|
|
except: |
|
|
|
|
raise ValueError('Cannot read the image file {}!'.format( |
|
|
|
|
raise ValueError("Cannot read the image file {}!".format( |
|
|
|
|
im_path)) |
|
|
|
|
else: |
|
|
|
|
image = im_path |
|
|
|
@ -217,7 +264,9 @@ class DecodeImg(Transform): |
|
|
|
|
Returns: |
|
|
|
|
dict: Decoded sample. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
if 'image' in sample: |
|
|
|
|
sample['image_ori'] = copy.deepcopy(sample['image']) |
|
|
|
|
sample['image'] = self.apply_im(sample['image']) |
|
|
|
|
if 'image2' in sample: |
|
|
|
|
sample['image2'] = self.apply_im(sample['image2']) |
|
|
|
@ -227,6 +276,7 @@ class DecodeImg(Transform): |
|
|
|
|
sample['image'] = self.apply_im(sample['image_t1']) |
|
|
|
|
sample['image2'] = self.apply_im(sample['image_t2']) |
|
|
|
|
if 'mask' in sample: |
|
|
|
|
sample['mask_ori'] = copy.deepcopy(sample['mask']) |
|
|
|
|
sample['mask'] = self.apply_mask(sample['mask']) |
|
|
|
|
im_height, im_width, _ = sample['image'].shape |
|
|
|
|
se_height, se_width = sample['mask'].shape |
|
|
|
@ -234,6 +284,7 @@ class DecodeImg(Transform): |
|
|
|
|
raise ValueError( |
|
|
|
|
"The height or width of the image is not same as the mask.") |
|
|
|
|
if 'aux_masks' in sample: |
|
|
|
|
sample['aux_masks_ori'] = copy.deepcopy(sample['aux_masks_ori']) |
|
|
|
|
sample['aux_masks'] = list( |
|
|
|
|
map(self.apply_mask, sample['aux_masks'])) |
|
|
|
|
# TODO: check the shape of auxiliary masks |
|
|
|
@ -244,61 +295,6 @@ class DecodeImg(Transform): |
|
|
|
|
return sample |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Compose(Transform): |
|
|
|
|
""" |
|
|
|
|
Apply a series of data augmentation to the input. |
|
|
|
|
All input images are in Height-Width-Channel ([H, W, C]) format. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
transforms (list[paddlers.transforms.Transform]): List of data preprocess or augmentations. |
|
|
|
|
Raises: |
|
|
|
|
TypeError: Invalid type of transforms. |
|
|
|
|
ValueError: Invalid length of transforms. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, transforms, to_uint8=True): |
|
|
|
|
super(Compose, self).__init__() |
|
|
|
|
if not isinstance(transforms, list): |
|
|
|
|
raise TypeError( |
|
|
|
|
'Type of transforms is invalid. Must be a list, but received is {}' |
|
|
|
|
.format(type(transforms))) |
|
|
|
|
if len(transforms) < 1: |
|
|
|
|
raise ValueError( |
|
|
|
|
'Length of transforms must not be less than 1, but received is {}' |
|
|
|
|
.format(len(transforms))) |
|
|
|
|
self.transforms = transforms |
|
|
|
|
self.decode_image = DecodeImg(to_uint8=to_uint8) |
|
|
|
|
self.arrange_outputs = None |
|
|
|
|
self.apply_im_only = False |
|
|
|
|
|
|
|
|
|
def __call__(self, sample): |
|
|
|
|
if self.apply_im_only: |
|
|
|
|
if 'mask' in sample: |
|
|
|
|
mask_backup = copy.deepcopy(sample['mask']) |
|
|
|
|
del sample['mask'] |
|
|
|
|
if 'aux_masks' in sample: |
|
|
|
|
aux_masks = copy.deepcopy(sample['aux_masks']) |
|
|
|
|
|
|
|
|
|
sample = self.decode_image(sample) |
|
|
|
|
|
|
|
|
|
for op in self.transforms: |
|
|
|
|
# skip batch transforms amd mixup |
|
|
|
|
if isinstance(op, (paddlers.transforms.BatchRandomResize, |
|
|
|
|
paddlers.transforms.BatchRandomResizeByShort, |
|
|
|
|
MixupImage)): |
|
|
|
|
continue |
|
|
|
|
sample = op(sample) |
|
|
|
|
|
|
|
|
|
if self.arrange_outputs is not None: |
|
|
|
|
if self.apply_im_only: |
|
|
|
|
sample['mask'] = mask_backup |
|
|
|
|
if 'aux_masks' in locals(): |
|
|
|
|
sample['aux_masks'] = aux_masks |
|
|
|
|
sample = self.arrange_outputs(sample) |
|
|
|
|
|
|
|
|
|
return sample |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Resize(Transform): |
|
|
|
|
""" |
|
|
|
|
Resize input. |
|
|
|
@ -323,7 +319,7 @@ class Resize(Transform): |
|
|
|
|
def __init__(self, target_size, interp='LINEAR', keep_ratio=False): |
|
|
|
|
super(Resize, self).__init__() |
|
|
|
|
if not (interp == "RANDOM" or interp in interp_dict): |
|
|
|
|
raise ValueError("interp should be one of {}".format( |
|
|
|
|
raise ValueError("`interp` should be one of {}.".format( |
|
|
|
|
interp_dict.keys())) |
|
|
|
|
if isinstance(target_size, int): |
|
|
|
|
target_size = (target_size, target_size) |
|
|
|
@ -331,7 +327,7 @@ class Resize(Transform): |
|
|
|
|
if not (isinstance(target_size, |
|
|
|
|
(list, tuple)) and len(target_size) == 2): |
|
|
|
|
raise TypeError( |
|
|
|
|
"target_size should be an int or a list of length 2, but received {}". |
|
|
|
|
"`target_size` should be an int or a list of length 2, but received {}.". |
|
|
|
|
format(target_size)) |
|
|
|
|
# (height, width) |
|
|
|
|
self.target_size = target_size |
|
|
|
@ -443,11 +439,11 @@ class RandomResize(Transform): |
|
|
|
|
def __init__(self, target_sizes, interp='LINEAR'): |
|
|
|
|
super(RandomResize, self).__init__() |
|
|
|
|
if not (interp == "RANDOM" or interp in interp_dict): |
|
|
|
|
raise ValueError("interp should be one of {}".format( |
|
|
|
|
raise ValueError("`interp` should be one of {}.".format( |
|
|
|
|
interp_dict.keys())) |
|
|
|
|
self.interp = interp |
|
|
|
|
assert isinstance(target_sizes, list), \ |
|
|
|
|
"target_size must be a list." |
|
|
|
|
"`target_size` must be a list." |
|
|
|
|
for i, item in enumerate(target_sizes): |
|
|
|
|
if isinstance(item, int): |
|
|
|
|
target_sizes[i] = (item, item) |
|
|
|
@ -478,7 +474,7 @@ class ResizeByShort(Transform): |
|
|
|
|
|
|
|
|
|
def __init__(self, short_size=256, max_size=-1, interp='LINEAR'): |
|
|
|
|
if not (interp == "RANDOM" or interp in interp_dict): |
|
|
|
|
raise ValueError("interp should be one of {}".format( |
|
|
|
|
raise ValueError("`interp` should be one of {}".format( |
|
|
|
|
interp_dict.keys())) |
|
|
|
|
super(ResizeByShort, self).__init__() |
|
|
|
|
self.short_size = short_size |
|
|
|
@ -522,11 +518,11 @@ class RandomResizeByShort(Transform): |
|
|
|
|
def __init__(self, short_sizes, max_size=-1, interp='LINEAR'): |
|
|
|
|
super(RandomResizeByShort, self).__init__() |
|
|
|
|
if not (interp == "RANDOM" or interp in interp_dict): |
|
|
|
|
raise ValueError("interp should be one of {}".format( |
|
|
|
|
raise ValueError("`interp` should be one of {}".format( |
|
|
|
|
interp_dict.keys())) |
|
|
|
|
self.interp = interp |
|
|
|
|
assert isinstance(short_sizes, list), \ |
|
|
|
|
"short_sizes must be a list." |
|
|
|
|
"`short_sizes` must be a list." |
|
|
|
|
|
|
|
|
|
self.short_sizes = short_sizes |
|
|
|
|
self.max_size = max_size |
|
|
|
@ -574,6 +570,7 @@ class RandomFlipOrRotate(Transform): |
|
|
|
|
|
|
|
|
|
# 定义数据增强 |
|
|
|
|
train_transforms = T.Compose([ |
|
|
|
|
T.DecodeImg(), |
|
|
|
|
T.RandomFlipOrRotate( |
|
|
|
|
probs = [0.3, 0.2] # 进行flip增强的概率是0.3,进行rotate增强的概率是0.2,不变的概率是0.5 |
|
|
|
|
probsf = [0.3, 0.25, 0, 0, 0] # flip增强时,使用水平flip、垂直flip的概率分别是0.3、0.25,水平且垂直flip、对角线flip、反对角线flip概率均为0,不变的概率是0.45 |
|
|
|
@ -609,12 +606,12 @@ class RandomFlipOrRotate(Transform): |
|
|
|
|
|
|
|
|
|
def apply_bbox(self, bbox, mode_id, flip_mode=True): |
|
|
|
|
raise TypeError( |
|
|
|
|
"Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks." |
|
|
|
|
"Currently, RandomFlipOrRotate is not available for object detection tasks." |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def apply_segm(self, bbox, mode_id, flip_mode=True): |
|
|
|
|
raise TypeError( |
|
|
|
|
"Currently, `paddlers.transforms.RandomFlipOrRotate` is not available for object detection tasks." |
|
|
|
|
"Currently, RandomFlipOrRotate is not available for object detection tasks." |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
def get_probs_range(self, probs): |
|
|
|
@ -845,11 +842,11 @@ class Normalize(Transform): |
|
|
|
|
from functools import reduce |
|
|
|
|
if reduce(lambda x, y: x * y, std) == 0: |
|
|
|
|
raise ValueError( |
|
|
|
|
'Std should not contain 0, but received is {}.'.format(std)) |
|
|
|
|
"`std` should not contain 0, but received is {}.".format(std)) |
|
|
|
|
if reduce(lambda x, y: x * y, |
|
|
|
|
[a - b for a, b in zip(max_val, min_val)]) == 0: |
|
|
|
|
raise ValueError( |
|
|
|
|
'(max_val - min_val) should not contain 0, but received is {}.'. |
|
|
|
|
"(`max_val` - `min_val`) should not contain 0, but received is {}.". |
|
|
|
|
format((np.asarray(max_val) - np.asarray(min_val)).tolist())) |
|
|
|
|
|
|
|
|
|
self.mean = mean |
|
|
|
@ -1153,11 +1150,11 @@ class RandomExpand(Transform): |
|
|
|
|
im_padding_value=127.5, |
|
|
|
|
label_padding_value=255): |
|
|
|
|
super(RandomExpand, self).__init__() |
|
|
|
|
assert upper_ratio > 1.01, "expand ratio must be larger than 1.01" |
|
|
|
|
assert upper_ratio > 1.01, "`upper_ratio` must be larger than 1.01." |
|
|
|
|
self.upper_ratio = upper_ratio |
|
|
|
|
self.prob = prob |
|
|
|
|
assert isinstance(im_padding_value, (Number, Sequence)), \ |
|
|
|
|
"fill value must be either float or sequence" |
|
|
|
|
"Value to fill must be either float or sequence." |
|
|
|
|
self.im_padding_value = im_padding_value |
|
|
|
|
self.label_padding_value = label_padding_value |
|
|
|
|
|
|
|
|
@ -1204,16 +1201,16 @@ class Pad(Transform): |
|
|
|
|
if isinstance(target_size, (list, tuple)): |
|
|
|
|
if len(target_size) != 2: |
|
|
|
|
raise ValueError( |
|
|
|
|
'`target_size` should include 2 elements, but it is {}'. |
|
|
|
|
"`target_size` should contain 2 elements, but it is {}.". |
|
|
|
|
format(target_size)) |
|
|
|
|
if isinstance(target_size, int): |
|
|
|
|
target_size = [target_size] * 2 |
|
|
|
|
|
|
|
|
|
assert pad_mode in [ |
|
|
|
|
-1, 0, 1, 2 |
|
|
|
|
], 'currently only supports four modes [-1, 0, 1, 2]' |
|
|
|
|
], "Currently only four modes are supported: [-1, 0, 1, 2]." |
|
|
|
|
if pad_mode == -1: |
|
|
|
|
assert offsets, 'if pad_mode is -1, offsets should not be None' |
|
|
|
|
assert offsets, "if `pad_mode` is -1, `offsets` should not be None." |
|
|
|
|
|
|
|
|
|
self.target_size = target_size |
|
|
|
|
self.size_divisor = size_divisor |
|
|
|
@ -1314,9 +1311,9 @@ class MixupImage(Transform): |
|
|
|
|
""" |
|
|
|
|
super(MixupImage, self).__init__() |
|
|
|
|
if alpha <= 0.0: |
|
|
|
|
raise ValueError("alpha should be positive in {}".format(self)) |
|
|
|
|
raise ValueError("`alpha` should be positive in MixupImage.") |
|
|
|
|
if beta <= 0.0: |
|
|
|
|
raise ValueError("beta should be positive in {}".format(self)) |
|
|
|
|
raise ValueError("`beta` should be positive in MixupImage.") |
|
|
|
|
self.alpha = alpha |
|
|
|
|
self.beta = beta |
|
|
|
|
self.mixup_epoch = mixup_epoch |
|
|
|
@ -1753,55 +1750,56 @@ class RandomSwap(Transform): |
|
|
|
|
|
|
|
|
|
def apply(self, sample): |
|
|
|
|
if 'image2' not in sample: |
|
|
|
|
raise ValueError('image2 is not found in the sample.') |
|
|
|
|
raise ValueError("'image2' is not found in the sample.") |
|
|
|
|
if random.random() < self.prob: |
|
|
|
|
sample['image'], sample['image2'] = sample['image2'], sample[ |
|
|
|
|
'image'] |
|
|
|
|
return sample |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArrangeSegmenter(Transform): |
|
|
|
|
class ReloadMask(Transform): |
|
|
|
|
def apply(self, sample): |
|
|
|
|
sample['mask'] = decode_seg_mask(sample['mask_ori']) |
|
|
|
|
if 'aux_masks' in sample: |
|
|
|
|
sample['aux_masks'] = list( |
|
|
|
|
map(decode_seg_mask, sample['aux_masks_ori'])) |
|
|
|
|
return sample |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Arrange(Transform): |
|
|
|
|
def __init__(self, mode): |
|
|
|
|
super(ArrangeSegmenter, self).__init__() |
|
|
|
|
super().__init__() |
|
|
|
|
if mode not in ['train', 'eval', 'test', 'quant']: |
|
|
|
|
raise ValueError( |
|
|
|
|
"mode should be defined as one of ['train', 'eval', 'test', 'quant']!" |
|
|
|
|
"`mode` should be defined as one of ['train', 'eval', 'test', 'quant']!" |
|
|
|
|
) |
|
|
|
|
self.mode = mode |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArrangeSegmenter(Arrange): |
|
|
|
|
def apply(self, sample): |
|
|
|
|
if 'mask' in sample: |
|
|
|
|
mask = sample['mask'] |
|
|
|
|
mask = mask.astype('int64') |
|
|
|
|
|
|
|
|
|
image = permute(sample['image'], False) |
|
|
|
|
if self.mode == 'train': |
|
|
|
|
mask = mask.astype('int64') |
|
|
|
|
return image, mask |
|
|
|
|
if self.mode == 'eval': |
|
|
|
|
mask = np.asarray(Image.open(mask)) |
|
|
|
|
mask = mask[np.newaxis, :, :].astype('int64') |
|
|
|
|
return image, mask |
|
|
|
|
if self.mode == 'test': |
|
|
|
|
return image, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArrangeChangeDetector(Transform): |
|
|
|
|
def __init__(self, mode): |
|
|
|
|
super(ArrangeChangeDetector, self).__init__() |
|
|
|
|
if mode not in ['train', 'eval', 'test', 'quant']: |
|
|
|
|
raise ValueError( |
|
|
|
|
"mode should be defined as one of ['train', 'eval', 'test', 'quant']!" |
|
|
|
|
) |
|
|
|
|
self.mode = mode |
|
|
|
|
|
|
|
|
|
class ArrangeChangeDetector(Arrange): |
|
|
|
|
def apply(self, sample): |
|
|
|
|
if 'mask' in sample: |
|
|
|
|
mask = sample['mask'] |
|
|
|
|
mask = mask.astype('int64') |
|
|
|
|
|
|
|
|
|
image_t1 = permute(sample['image'], False) |
|
|
|
|
image_t2 = permute(sample['image2'], False) |
|
|
|
|
if self.mode == 'train': |
|
|
|
|
mask = mask.astype('int64') |
|
|
|
|
masks = [mask] |
|
|
|
|
if 'aux_masks' in sample: |
|
|
|
|
masks.extend( |
|
|
|
@ -1810,22 +1808,12 @@ class ArrangeChangeDetector(Transform): |
|
|
|
|
image_t1, |
|
|
|
|
image_t2, ) + tuple(masks) |
|
|
|
|
if self.mode == 'eval': |
|
|
|
|
mask = np.asarray(Image.open(mask)) |
|
|
|
|
mask = mask[np.newaxis, :, :].astype('int64') |
|
|
|
|
return image_t1, image_t2, mask |
|
|
|
|
if self.mode == 'test': |
|
|
|
|
return image_t1, image_t2, |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArrangeClassifier(Transform): |
|
|
|
|
def __init__(self, mode): |
|
|
|
|
super(ArrangeClassifier, self).__init__() |
|
|
|
|
if mode not in ['train', 'eval', 'test', 'quant']: |
|
|
|
|
raise ValueError( |
|
|
|
|
"mode should be defined as one of ['train', 'eval', 'test', 'quant']!" |
|
|
|
|
) |
|
|
|
|
self.mode = mode |
|
|
|
|
|
|
|
|
|
class ArrangeClassifier(Arrange): |
|
|
|
|
def apply(self, sample): |
|
|
|
|
image = permute(sample['image'], False) |
|
|
|
|
if self.mode in ['train', 'eval']: |
|
|
|
@ -1834,15 +1822,7 @@ class ArrangeClassifier(Transform): |
|
|
|
|
return image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ArrangeDetector(Transform): |
|
|
|
|
def __init__(self, mode): |
|
|
|
|
super(ArrangeDetector, self).__init__() |
|
|
|
|
if mode not in ['train', 'eval', 'test', 'quant']: |
|
|
|
|
raise ValueError( |
|
|
|
|
"mode should be defined as one of ['train', 'eval', 'test', 'quant']!" |
|
|
|
|
) |
|
|
|
|
self.mode = mode |
|
|
|
|
|
|
|
|
|
class ArrangeDetector(Arrange): |
|
|
|
|
def apply(self, sample): |
|
|
|
|
if self.mode == 'eval' and 'gt_poly' in sample: |
|
|
|
|
del sample['gt_poly'] |
|
|
|
|