`ultralytics 8.0.233` improve Classify train augmentations (#4546)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
Co-authored-by: andresinsitu <andres.rodriguez@ingenieriainsitu.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
pull/7293/head^2 v8.0.233
fatih 11 months ago committed by GitHub
parent 6218b82072
commit 73dbb41920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
   10  README.md
   10  README.zh-CN.md
   10  docs/en/models/yolov8.md
   10  docs/en/tasks/classify.md
   32  docs/en/usage/cfg.md
   42  tests/test_python.py
    2  ultralytics/__init__.py
    3  ultralytics/cfg/default.yaml
  194  ultralytics/data/augment.py
   34  ultralytics/data/dataset.py
    5  ultralytics/engine/predictor.py
   11  ultralytics/models/yolo/classify/predict.py
    4  ultralytics/utils/torch_utils.py

@ -181,11 +181,11 @@ See [Classification Docs](https://docs.ultralytics.com/tasks/classify/) for usag
| Model | size<br><sup>(pixels) | acc<br><sup>top1 | acc<br><sup>top5 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
| -------------------------------------------------------------------------------------------- | --------------------- | ---------------- | ---------------- | ------------------------------ | ----------------------------------- | ------------------ | ------------------------ |
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** values are model accuracies on the [ImageNet](https://www.image-net.org/) dataset validation set. <br>Reproduce by `yolo val classify data=path/to/ImageNet device=0`
- **Speed** averaged over ImageNet val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. <br>Reproduce by `yolo val classify data=path/to/ImageNet batch=1 device=0|cpu`

@ -181,11 +181,11 @@ success = model.export(format="onnx") # export the model to ONNX format
| Model | size<br><sup>(pixels) | acc<br><sup>top1 | acc<br><sup>top5 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
| -------------------------------------------------------------------------------------------- | --------------- | ---------------- | ---------------- | --------------------------- | -------------------------------- | -------------- | ------------------------ |
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** values are model accuracies on the [ImageNet](https://www.image-net.org/) dataset validation set. <br>Reproduce by `yolo val classify data=path/to/ImageNet device=0`
- **Speed** averaged over ImageNet val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance. <br>Reproduce by `yolo val classify data=path/to/ImageNet batch=1 device=0|cpu`

@ -91,11 +91,11 @@ This table provides an overview of the YOLOv8 model variants, highlighting their
| Model | size<br><sup>(pixels) | acc<br><sup>top1 | acc<br><sup>top5 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
| -------------------------------------------------------------------------------------------- | --------------------- | ---------------- | ---------------- | ------------------------------ | ----------------------------------- | ------------------ | ------------------------ |
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
=== "Pose (COCO)"

@ -35,11 +35,11 @@ YOLOv8 pretrained Classify models are shown here. Detect, Segment and Pose model
| Model | size<br><sup>(pixels) | acc<br><sup>top1 | acc<br><sup>top5 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
|----------------------------------------------------------------------------------------------|-----------------------|------------------|------------------|--------------------------------|-------------------------------------|--------------------|--------------------------|
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** values are model accuracies on the [ImageNet](https://www.image-net.org/) dataset validation set.
<br>Reproduce by `yolo val classify data=path/to/ImageNet device=0`

@ -224,21 +224,23 @@ Export settings for YOLO models encompass configurations and options related to
Augmentation settings for YOLO models refer to the various transformations and modifications applied to the training data to increase the diversity and size of the dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO augmentation settings include the type and intensity of the transformations applied (e.g. random flips, rotations, cropping, color changes), the probability with which each transformation is applied, and the presence of additional features such as masks or multiple labels per box. Other factors that may affect the augmentation process include the size and composition of the original dataset and the specific task the model is being used for. It is important to carefully tune and experiment with these settings to ensure that the augmented dataset is diverse and representative enough to train a high-performing model.
| Key | Value | Description |
|---------------|---------|-------------------------------------------------|
| `hsv_h` | `0.015` | image HSV-Hue augmentation (fraction) |
| `hsv_s` | `0.7` | image HSV-Saturation augmentation (fraction) |
| `hsv_v` | `0.4` | image HSV-Value augmentation (fraction) |
| `degrees` | `0.0` | image rotation (+/- deg) |
| `translate` | `0.1` | image translation (+/- fraction) |
| `scale` | `0.5` | image scale (+/- gain) |
| `shear` | `0.0` | image shear (+/- deg) |
| `perspective` | `0.0` | image perspective (+/- fraction), range 0-0.001 |
| `flipud` | `0.0` | image flip up-down (probability) |
| `fliplr` | `0.5` | image flip left-right (probability) |
| `mosaic` | `1.0` | image mosaic (probability) |
| `mixup` | `0.0` | image mixup (probability) |
| `copy_paste` | `0.0` | segment copy-paste (probability) |
| Key | Value | Description |
|-----------------|-----------------|--------------------------------------------------------------------------------|
| `hsv_h` | `0.015` | image HSV-Hue augmentation (fraction) |
| `hsv_s` | `0.7` | image HSV-Saturation augmentation (fraction) |
| `hsv_v` | `0.4` | image HSV-Value augmentation (fraction) |
| `degrees` | `0.0` | image rotation (+/- deg) |
| `translate` | `0.1` | image translation (+/- fraction) |
| `scale` | `0.5` | image scale (+/- gain) |
| `shear` | `0.0` | image shear (+/- deg) |
| `perspective` | `0.0` | image perspective (+/- fraction), range 0-0.001 |
| `flipud` | `0.0` | image flip up-down (probability) |
| `fliplr` | `0.5` | image flip left-right (probability) |
| `mosaic` | `1.0` | image mosaic (probability) |
| `mixup` | `0.0` | image mixup (probability) |
| `copy_paste` | `0.0` | segment copy-paste (probability) |
| `auto_augment` | `'randaugment'` | auto augmentation policy for classification (randaugment, autoaugment, augmix) |
| `erasing` | `0.4` | probability of random erasing during classification training (0-1) |
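The two classification keys added to the table can be sanity-checked before they are passed on as overrides. The sketch below is only an illustration: `check_classify_aug` and `VALID_POLICIES` are hypothetical names, not part of ultralytics. It encodes just what the table states, namely that `auto_augment` is one of three policy names (or unset) and that `erasing` is a probability in the range 0-1.

```python
# Hypothetical helper, not an ultralytics API: validates the two new keys.
VALID_POLICIES = {"randaugment", "autoaugment", "augmix"}


def check_classify_aug(auto_augment="randaugment", erasing=0.4):
    """Return a dict of overrides, or raise ValueError on an invalid value."""
    if auto_augment is not None and auto_augment not in VALID_POLICIES:
        raise ValueError(
            f"auto_augment={auto_augment!r} must be one of {sorted(VALID_POLICIES)} or None"
        )
    if not 0.0 <= erasing <= 1.0:
        raise ValueError(f"erasing={erasing} must be a probability in [0, 1]")
    return {"auto_augment": auto_augment, "erasing": erasing}


# Defaults mirror the values listed in the table.
print(check_classify_aug())
# A different policy with a lower erasing probability.
print(check_classify_aug(auto_augment="augmix", erasing=0.2))
```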
## Logging, checkpoints, plotting and file management

@ -505,6 +505,48 @@ def test_hub():
smart_request('GET', 'http://github.com', progress=True)
@pytest.fixture
def image():
return cv2.imread(str(SOURCE))
@pytest.mark.parametrize(
'auto_augment, erasing, force_color_jitter',
[
(None, 0.0, False),
('randaugment', 0.5, True),
('augmix', 0.2, False),
('autoaugment', 0.0, True), ],
)
def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):
import torchvision.transforms as T
from ultralytics.data.augment import classify_augmentations
transform = classify_augmentations(
size=224,
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5),
scale=(0.08, 1.0),
ratio=(3.0 / 4.0, 4.0 / 3.0),
hflip=0.5,
vflip=0.5,
auto_augment=auto_augment,
hsv_h=0.015,
hsv_s=0.4,
hsv_v=0.4,
force_color_jitter=force_color_jitter,
erasing=erasing,
interpolation=T.InterpolationMode.BILINEAR,
)
transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
assert transformed_image.shape == (3, 224, 224)
assert torch.is_tensor(transformed_image)
assert transformed_image.dtype == torch.float32
@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason='environment is offline')
def test_model_tune():

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = '8.0.232'
__version__ = '8.0.233'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM

@ -113,6 +113,9 @@ fliplr: 0.5 # (float) image flip left-right (probability)
mosaic: 1.0 # (float) image mosaic (probability)
mixup: 0.0 # (float) image mixup (probability)
copy_paste: 0.0 # (float) segment copy-paste (probability)
auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix)
erasing: 0.4 # (float) probability of random erasing during classification training (0-1)
crop_fraction: 1.0 # (float) image crop fraction for classification evaluation/inference (0-1)
# Custom config.yaml ---------------------------------------------------------------------------------------------------
cfg: # (str, optional) for overriding defaults.yaml

@ -14,9 +14,14 @@ from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box
from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
from .utils import polygons2masks, polygons2masks_overlap
DEFAULT_MEAN = (0.0, 0.0, 0.0)
DEFAULT_STD = (1.0, 1.0, 1.0)
DEFAULT_CROP_FTACTION = 1.0
# TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
class BaseTransform:
@ -982,65 +987,144 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
# Classification augmentations -----------------------------------------------------------------------------------------
def classify_transforms(size=224, rect=False, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD
"""Transforms to apply if albumentations not installed."""
def classify_transforms(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
crop_fraction: float = DEFAULT_CROP_FTACTION,
):
"""
Classification transforms for evaluation/inference. Inspired by timm/data/transforms_factory.py.
Args:
size (int): image size
mean (tuple): mean values of RGB channels
std (tuple): std values of RGB channels
interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
crop_fraction (float): fraction of image to crop. default is 1.0.
Returns:
T.Compose: torchvision transforms
"""
if isinstance(size, (tuple, list)):
assert len(size) == 2
scale_size = tuple([math.floor(x / crop_fraction) for x in size])
else:
scale_size = math.floor(size / crop_fraction)
scale_size = (scale_size, scale_size)
# aspect ratio is preserved, crops center within image, no borders are added, image is lost
if scale_size[0] == scale_size[1]:
# simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
else:
# resize shortest edge to matching target dim for non-square target
tfl = [T.Resize(scale_size)]
tfl += [T.CenterCrop(size)]
tfl += [T.ToTensor(), T.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
)]
return T.Compose(tfl)
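The resize and crop arithmetic in `classify_transforms` can be followed without torchvision. The sketch below reproduces only the size computation from the function above; `eval_resize_and_crop` is a hypothetical name introduced for illustration. It returns the argument that would be handed to `T.Resize` and the one handed to `T.CenterCrop`.

```python
import math


def eval_resize_and_crop(size, crop_fraction=1.0):
    """Mirror the size arithmetic of classify_transforms (illustrative only)."""
    if isinstance(size, (tuple, list)):
        assert len(size) == 2
        scale_size = tuple(math.floor(x / crop_fraction) for x in size)
    else:
        scale_size = (math.floor(size / crop_fraction),) * 2
    # Square target: a scalar makes Resize scale the shortest edge.
    # Non-square target: the (h, w) pair is passed through unchanged.
    resize_target = scale_size[0] if scale_size[0] == scale_size[1] else scale_size
    return resize_target, size


print(eval_resize_and_crop(224, 1.0))    # (224, 224)
print(eval_resize_and_crop(224, 0.875))  # (256, 224)
```

With the default `crop_fraction` of 1.0 the image is resized so its shortest edge equals `size` and then center-cropped to `size`. A smaller fraction resizes to a larger intermediate image first, so the center crop keeps a proportionally smaller central region.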
# Classification augmentations train ---------------------------------------------------------------------------------------
def classify_augmentations(
size=224,
mean=DEFAULT_MEAN,
std=DEFAULT_STD,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.0,
auto_augment=None,
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
force_color_jitter=False,
erasing=0.,
interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
):
"""
Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
Args:
size (int): image size
scale (tuple): scale range of the image. default is (0.08, 1.0)
ratio (tuple): aspect ratio range of the image. default is (3./4., 4./3.)
mean (tuple): mean values of RGB channels
std (tuple): std values of RGB channels
hflip (float): probability of horizontal flip
vflip (float): probability of vertical flip
auto_augment (str): auto augmentation policy. can be 'randaugment', 'augmix', 'autoaugment' or None.
hsv_h (float): image HSV-Hue augmentation (fraction)
hsv_s (float): image HSV-Saturation augmentation (fraction)
hsv_v (float): image HSV-Value augmentation (fraction)
contrast (float): image contrast augmentation (fraction)
force_color_jitter (bool): force to apply color jitter even if auto augment is enabled
erasing (float): probability of random erasing
interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
Returns:
T.Compose: torchvision transforms
"""
# Transforms to apply if albumentations not installed
if not isinstance(size, int):
raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
transforms = [ClassifyLetterBox(size, auto=True) if rect else CenterCrop(size), ToTensor()]
if any(mean) or any(std):
transforms.append(T.Normalize(mean, std, inplace=True))
return T.Compose(transforms)
def hsv2colorjitter(h, s, v):
"""Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
return v, v, s, h
def classify_albumentations(
augment=True,
size=224,
scale=(0.08, 1.0),
hflip=0.5,
vflip=0.0,
hsv_h=0.015, # image HSV-Hue augmentation (fraction)
hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
hsv_v=0.4, # image HSV-Value augmentation (fraction)
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
std=(1.0, 1.0, 1.0), # IMAGENET_STD
auto_aug=False,
):
"""YOLOv8 classification Albumentations (optional, only used if package is installed)."""
prefix = colorstr('albumentations: ')
try:
import albumentations as A
from albumentations.pytorch import ToTensorV2
check_version(A.__version__, '1.0.3', hard=True) # version requirement
if augment: # Resize and crop
T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
if auto_aug:
# TODO: implement AugMix, AutoAug & RandAug in albumentations
LOGGER.info(f'{prefix}auto augmentations are currently not supported')
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.:
primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
if vflip > 0.:
primary_tfl += [T.RandomVerticalFlip(p=vflip)]
secondary_tfl = []
disable_color_jitter = False
if auto_augment:
assert isinstance(auto_augment, str)
# color jitter is typically disabled if AA/RA on,
# this allows override without breaking old hparm cfgs
disable_color_jitter = not force_color_jitter
if auto_augment == 'randaugment':
if TORCHVISION_0_11:
secondary_tfl += [T.RandAugment(interpolation=interpolation)]
else:
if hflip > 0:
T += [A.HorizontalFlip(p=hflip)]
if vflip > 0:
T += [A.VerticalFlip(p=vflip)]
if any((hsv_h, hsv_s, hsv_v)):
T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
else: # Use fixed crop for eval set (reproducibility)
T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
return A.Compose(T)
except ImportError: # package not installed, skip
pass
except Exception as e:
LOGGER.info(f'{prefix}{e}')
LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
elif auto_augment == 'augmix':
if TORCHVISION_0_13:
secondary_tfl += [T.AugMix(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
elif auto_augment == 'autoaugment':
if TORCHVISION_0_10:
secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
else:
LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
else:
raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
f'"augmix", "autoaugment" or None')
if not disable_color_jitter:
secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
final_tfl = [
T.ToTensor(),
T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
T.RandomErasing(p=erasing, inplace=True)]
return T.Compose(primary_tfl + secondary_tfl + final_tfl)
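The order in which `classify_augmentations` assembles its pipeline, and the rule that an auto-augment policy suppresses color jitter unless `force_color_jitter` is set, can be traced with a dependency-free sketch. `plan_train_pipeline` is a hypothetical helper that returns transform names instead of torchvision objects, and it assumes the installed torchvision is new enough for the requested policy (the real function logs a warning and skips the policy otherwise).

```python
def plan_train_pipeline(hflip=0.5, vflip=0.0, auto_augment=None, force_color_jitter=False):
    """Return the names of the transforms in the order they would be composed."""
    names = ["RandomResizedCrop"]  # primary transforms
    if hflip > 0.0:
        names.append("RandomHorizontalFlip")
    if vflip > 0.0:
        names.append("RandomVerticalFlip")

    disable_color_jitter = False
    if auto_augment:  # secondary transforms
        policies = {
            "randaugment": "RandAugment",
            "augmix": "AugMix",
            "autoaugment": "AutoAugment",
        }
        if auto_augment not in policies:
            raise ValueError(f"Invalid auto_augment policy: {auto_augment}")
        # Color jitter is dropped when a policy is active, unless forced.
        disable_color_jitter = not force_color_jitter
        names.append(policies[auto_augment])
    if not disable_color_jitter:
        names.append("ColorJitter")

    # Final transforms; RandomErasing is always present and controlled by its p.
    names += ["ToTensor", "Normalize", "RandomErasing"]
    return names


print(plan_train_pipeline())
print(plan_train_pipeline(auto_augment="randaugment"))
print(plan_train_pipeline(auto_augment="augmix", force_color_jitter=True))
```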
# NOTE: keep this class for backward compatibility
class ClassifyLetterBox:
"""
YOLOv8 LetterBox class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
@ -1091,6 +1175,7 @@ class ClassifyLetterBox:
return im_out
# NOTE: keep this class for backward compatibility
class CenterCrop:
"""YOLOv8 CenterCrop class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
T.Compose([CenterCrop(size), ToTensor()]).
@ -1117,6 +1202,7 @@ class CenterCrop:
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
# NOTE: keep this class for backward compatibility
class ToTensor:
"""YOLOv8 ToTensor class for image preprocessing, i.e., T.Compose([LetterBox(size), ToTensor()])."""

@ -8,10 +8,11 @@ import cv2
import numpy as np
import torch
import torchvision
from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
@ -225,19 +226,17 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
self.cache_disk = cache == 'disk'
self.samples = self.verify_images() # filter out bad images
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
self.torch_transforms = classify_transforms(args.imgsz, rect=args.rect)
self.album_transforms = classify_albumentations(
augment=augment,
size=args.imgsz,
scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
hflip=args.fliplr,
vflip=args.flipud,
hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
std=(1.0, 1.0, 1.0), # IMAGENET_STD
auto_aug=False) if augment else None
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
self.torch_transforms = classify_augmentations(size=args.imgsz,
scale=scale,
hflip=args.fliplr,
vflip=args.flipud,
erasing=args.erasing,
auto_augment=args.auto_augment,
hsv_h=args.hsv_h,
hsv_s=args.hsv_s,
hsv_v=args.hsv_v) if augment else classify_transforms(
size=args.imgsz, crop_fraction=args.crop_fraction)
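The dataset derives the `RandomResizedCrop` scale range from the `scale` hyperparameter as `(1.0 - args.scale, 1.0)`. A short arithmetic sketch shows what that means for the documented default of `scale=0.5`, and which value would be needed to reach the `(0.08, 1.0)` range mentioned in the inline comment. `crop_scale_range` is a hypothetical name used only here.

```python
def crop_scale_range(scale_gain):
    """Map the `scale` hyperparameter to a (min, max) crop area fraction."""
    return (1.0 - scale_gain, 1.0)


# Documented default scale=0.5: crops cover between 50% and 100% of the image area.
print(crop_scale_range(0.5))  # (0.5, 1.0)

# Reaching the ImageNet-style lower bound of 0.08 would require scale=0.92.
lo, hi = crop_scale_range(0.92)
print(round(lo, 2), hi)  # 0.08 1.0
```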
def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""
@ -250,10 +249,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
if self.album_transforms:
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
else:
sample = self.torch_transforms(im)
# Convert NumPy array to PIL image
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
sample = self.torch_transforms(im)
return {'img': sample, 'cls': j}
def __len__(self) -> int:

@ -210,8 +210,9 @@ class BasePredictor:
def setup_source(self, source):
"""Sets up source and inference mode."""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
self.imgsz[0])) if self.args.task == 'classify' else None
self.transforms = getattr(
self.model.model, 'transforms', classify_transforms(
self.imgsz[0], crop_fraction=self.args.crop_fraction)) if self.args.task == 'classify' else None
self.dataset = load_inference_source(source=source,
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,

@ -1,6 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import cv2
import torch
from PIL import Image
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
@ -29,11 +31,18 @@ class ClassificationPredictor(BasePredictor):
"""Initializes ClassificationPredictor setting the task to 'classify'."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'classify'
self._legacy_transform_name = 'ultralytics.yolo.data.augment.ToTensor'
def preprocess(self, img):
"""Converts input image to model-compatible data type."""
if not isinstance(img, torch.Tensor):
img = torch.stack([self.transforms(im) for im in img], dim=0)
is_legacy_transform = any(self._legacy_transform_name in str(transform)
for transform in self.transforms.transforms)
if is_legacy_transform: # to handle legacy transforms
img = torch.stack([self.transforms(im) for im in img], dim=0)
else:
img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img],
dim=0)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
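The legacy check in `preprocess` relies on a substring match against the string form of each transform. The sketch below imitates that test with stand-in classes so it runs without ultralytics or torchvision. `FakeLegacyToTensor` and `FakeNewToTensor` are hypothetical, and the legacy repr shown is an assumption based on Python's default `<module.Class object at ...>` format for objects without a custom `__repr__`.

```python
LEGACY_NAME = "ultralytics.yolo.data.augment.ToTensor"


class FakeLegacyToTensor:
    """Stand-in for a transform unpickled from an old checkpoint (assumed repr)."""

    def __repr__(self):
        return "<ultralytics.yolo.data.augment.ToTensor object at 0x0>"


class FakeNewToTensor:
    """Stand-in for a torchvision-style transform with a short repr."""

    def __repr__(self):
        return "ToTensor()"


def is_legacy(transforms):
    """True if any transform's string form contains the legacy module path."""
    return any(LEGACY_NAME in str(t) for t in transforms)


print(is_legacy([FakeLegacyToTensor()]))  # True
print(is_legacy([FakeNewToTensor()]))     # False
```

When the check is true, images are passed to the transforms as NumPy BGR arrays, as before; otherwise each image is converted to an RGB PIL image first, which is what the torchvision-based pipeline expects.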

@ -15,6 +15,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.utils.checks import check_version
@ -26,6 +27,9 @@ except ImportError:
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_2_0 = check_version(torch.__version__, '2.0.0')
TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
TORCHVISION_0_11 = check_version(torchvision.__version__, '0.11.0')
TORCHVISION_0_13 = check_version(torchvision.__version__, '0.13.0')
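The three flags gate the auto-augment policies by minimum torchvision version: 0.10.0 for `autoaugment`, 0.11.0 for `randaugment`, and 0.13.0 for `augmix`. The sketch below shows the same gating with plain tuple comparison. It is not how `check_version` is implemented; `parse_version`, `MIN_TORCHVISION`, and `policy_available` are hypothetical names for illustration.

```python
import re

# Minimum torchvision version per policy, as used by the warnings in augment.py.
MIN_TORCHVISION = {
    "autoaugment": (0, 10, 0),
    "randaugment": (0, 11, 0),
    "augmix": (0, 13, 0),
}


def parse_version(version):
    """Parse 'X.Y.Z' into a tuple of ints, ignoring a local suffix such as '+cu113'."""
    core = version.split("+")[0]
    matches = (re.match(r"\d+", part) for part in core.split(".")[:3])
    return tuple(int(m.group()) for m in matches if m)


def policy_available(policy, torchvision_version):
    """True if the given torchvision version meets the policy's minimum."""
    return parse_version(torchvision_version) >= MIN_TORCHVISION[policy]


print(policy_available("randaugment", "0.12.0+cu113"))  # True
print(policy_available("augmix", "0.12.0+cu113"))       # False
print(policy_available("autoaugment", "0.9.1"))         # False
```

Comparing integer tuples avoids the pitfall of string comparison, where `"0.9.1"` would sort after `"0.10.0"`.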
@contextmanager
