Fix torchvision InterpolationMode warnings (#14632)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
pull/14638/head
Glenn Jocher 7 months ago committed by GitHub
parent 0ec70b0054
commit 72466b9648
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 11
      ultralytics/data/augment.py

@ -2322,7 +2322,7 @@ def classify_transforms(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
interpolation=Image.BILINEAR,
interpolation="BILINEAR",
crop_fraction: float = DEFAULT_CROP_FRACTION,
):
"""
@ -2337,7 +2337,7 @@ def classify_transforms(
tuple, it defines (height, width).
mean (tuple): Mean values for each RGB channel used in normalization.
std (tuple): Standard deviation values for each RGB channel used in normalization.
interpolation (int): Interpolation method for resizing.
interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
crop_fraction (float): Fraction of the image to be cropped.
Returns:
@ -2360,7 +2360,7 @@ def classify_transforms(
# Aspect ratio is preserved, crops center within image, no borders are added, image is lost
if scale_size[0] == scale_size[1]:
# Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg)
tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))]
else:
# Resize the shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)]
@ -2389,7 +2389,7 @@ def classify_augmentations(
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.0,
interpolation=Image.BILINEAR,
interpolation="BILINEAR",
):
"""
Creates a composition of image augmentation transforms for classification tasks.
@ -2411,7 +2411,7 @@ def classify_augmentations(
hsv_v (float): Image HSV-Value augmentation factor.
force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled.
erasing (float): Probability of random erasing.
interpolation (int): Interpolation method.
interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'.
Returns:
(torchvision.transforms.Compose): A composition of image augmentation transforms.
@ -2427,6 +2427,7 @@ def classify_augmentations(
raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)")
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
interpolation = getattr(T.InterpolationMode, interpolation)
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.0:
primary_tfl.append(T.RandomHorizontalFlip(p=hflip))

Loading…
Cancel
Save