|
|
|
@ -1223,16 +1223,13 @@ def classify_transforms( |
|
|
|
|
else: |
|
|
|
|
# Resize the shortest edge to matching target dim for non-square target |
|
|
|
|
tfl = [T.Resize(scale_size)] |
|
|
|
|
tfl += [T.CenterCrop(size)] |
|
|
|
|
|
|
|
|
|
tfl += [ |
|
|
|
|
T.ToTensor(), |
|
|
|
|
T.Normalize( |
|
|
|
|
mean=torch.tensor(mean), |
|
|
|
|
std=torch.tensor(std), |
|
|
|
|
), |
|
|
|
|
] |
|
|
|
|
|
|
|
|
|
tfl.extend( |
|
|
|
|
[ |
|
|
|
|
T.CenterCrop(size), |
|
|
|
|
T.ToTensor(), |
|
|
|
|
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)), |
|
|
|
|
] |
|
|
|
|
) |
|
|
|
|
return T.Compose(tfl) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -1284,9 +1281,9 @@ def classify_augmentations( |
|
|
|
|
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range |
|
|
|
|
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] |
|
|
|
|
if hflip > 0.0: |
|
|
|
|
primary_tfl += [T.RandomHorizontalFlip(p=hflip)] |
|
|
|
|
primary_tfl.append(T.RandomHorizontalFlip(p=hflip)) |
|
|
|
|
if vflip > 0.0: |
|
|
|
|
primary_tfl += [T.RandomVerticalFlip(p=vflip)] |
|
|
|
|
primary_tfl.append(T.RandomVerticalFlip(p=vflip)) |
|
|
|
|
|
|
|
|
|
secondary_tfl = [] |
|
|
|
|
disable_color_jitter = False |
|
|
|
@ -1298,19 +1295,19 @@ def classify_augmentations( |
|
|
|
|
|
|
|
|
|
if auto_augment == "randaugment": |
|
|
|
|
if TORCHVISION_0_11: |
|
|
|
|
secondary_tfl += [T.RandAugment(interpolation=interpolation)] |
|
|
|
|
secondary_tfl.append(T.RandAugment(interpolation=interpolation)) |
|
|
|
|
else: |
|
|
|
|
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.') |
|
|
|
|
|
|
|
|
|
elif auto_augment == "augmix": |
|
|
|
|
if TORCHVISION_0_13: |
|
|
|
|
secondary_tfl += [T.AugMix(interpolation=interpolation)] |
|
|
|
|
secondary_tfl.append(T.AugMix(interpolation=interpolation)) |
|
|
|
|
else: |
|
|
|
|
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.') |
|
|
|
|
|
|
|
|
|
elif auto_augment == "autoaugment": |
|
|
|
|
if TORCHVISION_0_10: |
|
|
|
|
secondary_tfl += [T.AutoAugment(interpolation=interpolation)] |
|
|
|
|
secondary_tfl.append(T.AutoAugment(interpolation=interpolation)) |
|
|
|
|
else: |
|
|
|
|
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.') |
|
|
|
|
|
|
|
|
@ -1321,7 +1318,7 @@ def classify_augmentations( |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
if not disable_color_jitter: |
|
|
|
|
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)] |
|
|
|
|
secondary_tfl.append(T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)) |
|
|
|
|
|
|
|
|
|
final_tfl = [ |
|
|
|
|
T.ToTensor(), |
|
|
|
|