diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py index 90cdd2313..cd084f3e6 100644 --- a/ultralytics/data/augment.py +++ b/ultralytics/data/augment.py @@ -2322,7 +2322,7 @@ def classify_transforms( size=224, mean=DEFAULT_MEAN, std=DEFAULT_STD, - interpolation=Image.BILINEAR, + interpolation="BILINEAR", crop_fraction: float = DEFAULT_CROP_FRACTION, ): """ @@ -2337,7 +2337,7 @@ def classify_transforms( tuple, it defines (height, width). mean (tuple): Mean values for each RGB channel used in normalization. std (tuple): Standard deviation values for each RGB channel used in normalization. - interpolation (int): Interpolation method for resizing. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. crop_fraction (float): Fraction of the image to be cropped. Returns: @@ -2360,7 +2360,7 @@ def classify_transforms( # Aspect ratio is preserved, crops center within image, no borders are added, image is lost if scale_size[0] == scale_size[1]: # Simple case, use torchvision built-in Resize with the shortest edge mode (scalar size arg) - tfl = [T.Resize(scale_size[0], interpolation=interpolation)] + tfl = [T.Resize(scale_size[0], interpolation=getattr(T.InterpolationMode, interpolation))] else: # Resize the shortest edge to matching target dim for non-square target tfl = [T.Resize(scale_size)] @@ -2389,7 +2389,7 @@ def classify_augmentations( hsv_v=0.4, # image HSV-Value augmentation (fraction) force_color_jitter=False, erasing=0.0, - interpolation=Image.BILINEAR, + interpolation="BILINEAR", ): """ Creates a composition of image augmentation transforms for classification tasks. @@ -2411,7 +2411,7 @@ def classify_augmentations( hsv_v (float): Image HSV-Value augmentation factor. force_color_jitter (bool): Whether to apply color jitter even if auto augment is enabled. erasing (float): Probability of random erasing. - interpolation (int): Interpolation method. + interpolation (str): Interpolation method of either 'NEAREST', 'BILINEAR' or 'BICUBIC'. Returns: (torchvision.transforms.Compose): A composition of image augmentation transforms. @@ -2427,6 +2427,7 @@ def classify_augmentations( raise TypeError(f"classify_transforms() size {size} must be integer, not (list, tuple)") scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range + interpolation = getattr(T.InterpolationMode, interpolation) primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)] if hflip > 0.0: primary_tfl.append(T.RandomHorizontalFlip(p=hflip))