Replace `+=` with faster list `.append()` (#13849)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/13862/head
Kayzwer 5 months ago committed by GitHub
parent 105edd4dc1
commit c497732278
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 29
      ultralytics/data/augment.py

@ -1223,16 +1223,13 @@ def classify_transforms(
else:
# Resize the shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]
tfl += [
T.ToTensor(),
T.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
),
]
tfl.extend(
[
T.CenterCrop(size),
T.ToTensor(),
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
]
)
return T.Compose(tfl)
@ -1284,9 +1281,9 @@ def classify_augmentations(
ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0)) # default imagenet ratio range
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.0:
primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
primary_tfl.append(T.RandomHorizontalFlip(p=hflip))
if vflip > 0.0:
primary_tfl += [T.RandomVerticalFlip(p=vflip)]
primary_tfl.append(T.RandomVerticalFlip(p=vflip))
secondary_tfl = []
disable_color_jitter = False
@ -1298,19 +1295,19 @@ def classify_augmentations(
if auto_augment == "randaugment":
if TORCHVISION_0_11:
secondary_tfl += [T.RandAugment(interpolation=interpolation)]
secondary_tfl.append(T.RandAugment(interpolation=interpolation))
else:
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
elif auto_augment == "augmix":
if TORCHVISION_0_13:
secondary_tfl += [T.AugMix(interpolation=interpolation)]
secondary_tfl.append(T.AugMix(interpolation=interpolation))
else:
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
elif auto_augment == "autoaugment":
if TORCHVISION_0_10:
secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
secondary_tfl.append(T.AutoAugment(interpolation=interpolation))
else:
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
@ -1321,7 +1318,7 @@ def classify_augmentations(
)
if not disable_color_jitter:
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
secondary_tfl.append(T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h))
final_tfl = [
T.ToTensor(),

Loading…
Cancel
Save