|
|
|
@ -2322,7 +2322,7 @@ def classify_transforms( |
|
|
|
|
size=224, |
|
|
|
|
mean=DEFAULT_MEAN, |
|
|
|
|
std=DEFAULT_STD, |
|
|
|
|
interpolation=Image.BILINEAR, |
|
|
|
|
interpolation="BILINEAR", |
|
|
|
|
crop_fraction: float = DEFAULT_CROP_FRACTION, |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
@ -2337,7 +2337,7 @@ def classify_transforms( |
|
|
|
|
tuple, it defines (height, width). |
|
|
|
|
mean (tuple): Mean values for each RGB channel used in normalization. |
|
|
|
|
std (tuple): Standard deviation values for each RGB channel used in normalization. |
|
|
|
|
interpolation (int): Interpolation method for resizing. |
|
|
|
|
interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. |
|
|
|
|
crop_fraction (float): Fraction of the image to be cropped. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
@ -2360,7 +2360,7 @@ def classify_transforms( |
|
|
|
|
# Aspect ratio is preserved, crops center within image, no borders are added, image is lost |
|
|
|
|
if scale_size[0] == scale_size[1]: |
|
|
|
|
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg) |
|
|
|
|
tfl = [T.Resize(scale_size[0], interpolation=interpolation)] |
|
|
|
|
tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))] |
|
|
|
|
else: |
|
|
|
|
# Resize the shortest edge to matching target dim for non-square target |
|
|
|
|
tfl = [T.Resize(scale_size)] |
|
|
|
@ -2389,7 +2389,7 @@ def classify_augmentations( |
|
|
|
|
hsv_v=0.4, # image HSV-Value augmentation (fraction) |
|
|
|
|
force_color_jitter=False, |
|
|
|
|
erasing=0.0, |
|
|
|
|
interpolation=Image.BILINEAR, |
|
|
|
|
interpolation="BILINEAR", |
|
|
|
|
): |
|
|
|
|
""" |
|
|
|
|
Creates a composition of image augmentation transforms for classification tasks. |
|
|
|
@ -2411,7 +2411,7 @@ def classify_augmentations( |
|
|
|
|
hsv_v (float): Image HSV-Value augmentation factor. |
|
|
|
|
force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled. |
|
|
|
|
erasing (float): Probability of random erasing. |
|
|
|
|
interpolation (int): Interpolation method. |
|
|
|
|
interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. |
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
(torchvision.transforms.Compose): A composition of image augmentation transforms. |
|
|
|
@ -2427,6 +2427,7 @@ def classify_augmentations( |
|
|
|
|
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") |
|
|
|
|
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range |
|
|
|
|
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range |
|
|
|
|
interpolation = getattr(T.InterpolationMode, interpolation) |
|
|
|
|
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] |
|
|
|
|
if hflip > 0.0: |
|
|
|
|
primary_tfl.append(T.RandomHorizontalFlip(p=hflip)) |
|
|
|
|