diff --git a/README.md b/README.md
index fcc51b4fd..e798a33ab 100644
--- a/README.md
+++ b/README.md
@@ -181,11 +181,11 @@ See [Classification Docs](https://docs.ultralytics.com/tasks/classify/) for usag
| Model                                                                                        | size<br><sup>(pixels) | acc<br><sup>top1 | acc<br><sup>top5 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
| -------------------------------------------------------------------------------------------- | --------------------- | ---------------- | ---------------- | ------------------------------ | ----------------------------------- | ------------------ | ------------------------ |
-| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
-| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
-| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
-| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
-| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
+| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
+| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
+| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
+| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
+| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** values are model accuracies on the [ImageNet](https://www.image-net.org/) dataset validation set.
Reproduce by `yolo val classify data=path/to/ImageNet device=0`
- **Speed** averaged over ImageNet val images using an [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) instance.
Reproduce by `yolo val classify data=path/to/ImageNet batch=1 device=0|cpu`
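+
+A minimal Python sketch of the reproduce commands above (assuming a local ImageNet copy at `path/to/ImageNet`; `metrics.top1`/`metrics.top5` report the accuracies listed in the table):
+
+```python
+from ultralytics import YOLO
+
+model = YOLO("yolov8n-cls.pt")  # pretrained classification weights
+metrics = model.val(data="path/to/ImageNet", device=0)  # ImageNet validation
+print(metrics.top1, metrics.top5)
+```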
diff --git a/README.zh-CN.md b/README.zh-CN.md
index 49f7f2baf..cfad61aab 100644
--- a/README.zh-CN.md
+++ b/README.zh-CN.md
@@ -181,11 +181,11 @@ success = model.export(format="onnx") # 将模型导出为 ONNX 格式
| 模型                                                                                          | 尺寸<br><sup>(像素) | acc<br><sup>top1 | acc<br><sup>top5 | 速度<br><sup>CPU ONNX<br>(ms) | 速度<br><sup>A100 TensorRT<br>(ms) | 参数<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
| -------------------------------------------------------------------------------------------- | --------------- | ---------------- | ---------------- | --------------------------- | -------------------------------- | -------------- | ------------------------ |
-| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
-| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
-| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
-| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
-| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
+| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
+| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
+| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
+| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
+| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** 值是模型在 [ImageNet](https://www.image-net.org/) 数据集验证集上的准确率。
通过 `yolo val classify data=path/to/ImageNet device=0` 复现
- **速度** 是使用 [Amazon EC2 P4d](https://aws.amazon.com/ec2/instance-types/p4/) 实例对 ImageNet val 图像进行平均计算的。
通过 `yolo val classify data=path/to/ImageNet batch=1 device=0|cpu` 复现
diff --git a/docs/en/models/yolov8.md b/docs/en/models/yolov8.md
index fd7013028..624bd286a 100644
--- a/docs/en/models/yolov8.md
+++ b/docs/en/models/yolov8.md
@@ -91,11 +91,11 @@ This table provides an overview of the YOLOv8 model variants, highlighting their
| Model                                                                                        | size<br><sup>(pixels) | acc<br><sup>top1 | acc<br><sup>top5 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
| -------------------------------------------------------------------------------------------- | --------------------- | ---------------- | ---------------- | ------------------------------ | ----------------------------------- | ------------------ | ------------------------ |
- | [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
- | [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
- | [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
- | [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
- | [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
+ | [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
+ | [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
+ | [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
+ | [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
+ | [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
=== "Pose (COCO)"
diff --git a/docs/en/tasks/classify.md b/docs/en/tasks/classify.md
index fc5ec085f..dc74a68eb 100644
--- a/docs/en/tasks/classify.md
+++ b/docs/en/tasks/classify.md
@@ -35,11 +35,11 @@ YOLOv8 pretrained Classify models are shown here. Detect, Segment and Pose model
| Model                                                                                        | size<br><sup>(pixels) | acc<br><sup>top1 | acc<br><sup>top5 | Speed<br><sup>CPU ONNX<br>(ms) | Speed<br><sup>A100 TensorRT<br>(ms) | params<br><sup>(M) | FLOPs<br><sup>(B) at 640 |
|----------------------------------------------------------------------------------------------|-----------------------|------------------|------------------|--------------------------------|-------------------------------------|--------------------|--------------------------|
-| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 66.6 | 87.0 | 12.9 | 0.31 | 2.7 | 4.3 |
-| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 72.3 | 91.1 | 23.4 | 0.35 | 6.4 | 13.5 |
-| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.4 | 93.2 | 85.4 | 0.62 | 17.0 | 42.7 |
-| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 78.0 | 94.1 | 163.0 | 0.87 | 37.5 | 99.7 |
-| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 78.4 | 94.3 | 232.0 | 1.01 | 57.4 | 154.8 |
+| [YOLOv8n-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-cls.pt) | 224 | 69.0 | 88.3 | 12.9 | 0.31 | 2.7 | 4.3 |
+| [YOLOv8s-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-cls.pt) | 224 | 73.8 | 91.7 | 23.4 | 0.35 | 6.4 | 13.5 |
+| [YOLOv8m-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-cls.pt) | 224 | 76.8 | 93.5 | 85.4 | 0.62 | 17.0 | 42.7 |
+| [YOLOv8l-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-cls.pt) | 224 | 76.8 | 93.5 | 163.0 | 0.87 | 37.5 | 99.7 |
+| [YOLOv8x-cls](https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-cls.pt) | 224 | 79.0 | 94.6 | 232.0 | 1.01 | 57.4 | 154.8 |
- **acc** values are model accuracies on the [ImageNet](https://www.image-net.org/) dataset validation set.
Reproduce by `yolo val classify data=path/to/ImageNet device=0`
diff --git a/docs/en/usage/cfg.md b/docs/en/usage/cfg.md
index 2e822aeeb..544dd38de 100644
--- a/docs/en/usage/cfg.md
+++ b/docs/en/usage/cfg.md
@@ -224,21 +224,23 @@ Export settings for YOLO models encompass configurations and options related to
Augmentation settings for YOLO models refer to the various transformations and modifications applied to the training data to increase the diversity and size of the dataset. These settings can affect the model's performance, speed, and accuracy. Some common YOLO augmentation settings include the type and intensity of the transformations applied (e.g. random flips, rotations, cropping, color changes), the probability with which each transformation is applied, and the presence of additional features such as masks or multiple labels per box. Other factors that may affect the augmentation process include the size and composition of the original dataset and the specific task the model is being used for. It is important to carefully tune and experiment with these settings to ensure that the augmented dataset is diverse and representative enough to train a high-performing model.
-| Key | Value | Description |
-|---------------|---------|-------------------------------------------------|
-| `hsv_h` | `0.015` | image HSV-Hue augmentation (fraction) |
-| `hsv_s` | `0.7` | image HSV-Saturation augmentation (fraction) |
-| `hsv_v` | `0.4` | image HSV-Value augmentation (fraction) |
-| `degrees` | `0.0` | image rotation (+/- deg) |
-| `translate` | `0.1` | image translation (+/- fraction) |
-| `scale` | `0.5` | image scale (+/- gain) |
-| `shear` | `0.0` | image shear (+/- deg) |
-| `perspective` | `0.0` | image perspective (+/- fraction), range 0-0.001 |
-| `flipud` | `0.0` | image flip up-down (probability) |
-| `fliplr` | `0.5` | image flip left-right (probability) |
-| `mosaic` | `1.0` | image mosaic (probability) |
-| `mixup` | `0.0` | image mixup (probability) |
-| `copy_paste` | `0.0` | segment copy-paste (probability) |
+| Key | Value | Description |
+|-----------------|-----------------|--------------------------------------------------------------------------------|
+| `hsv_h` | `0.015` | image HSV-Hue augmentation (fraction) |
+| `hsv_s` | `0.7` | image HSV-Saturation augmentation (fraction) |
+| `hsv_v` | `0.4` | image HSV-Value augmentation (fraction) |
+| `degrees` | `0.0` | image rotation (+/- deg) |
+| `translate` | `0.1` | image translation (+/- fraction) |
+| `scale` | `0.5` | image scale (+/- gain) |
+| `shear` | `0.0` | image shear (+/- deg) |
+| `perspective` | `0.0` | image perspective (+/- fraction), range 0-0.001 |
+| `flipud` | `0.0` | image flip up-down (probability) |
+| `fliplr` | `0.5` | image flip left-right (probability) |
+| `mosaic` | `1.0` | image mosaic (probability) |
+| `mixup` | `0.0` | image mixup (probability) |
+| `copy_paste` | `0.0` | segment copy-paste (probability) |
+| `auto_augment` | `'randaugment'` | auto augmentation policy for classification (randaugment, autoaugment, augmix) |
+| `erasing`       | `0.4`           | probability of random erasing during classification training (0-1)             |
+| `crop_fraction` | `1.0`           | image crop fraction for classification evaluation/inference (0-1)               |
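+
+As a brief sketch (Python API; `mnist160` is just a small stand-in classification dataset), the classification-only keys above are passed as ordinary `train()` overrides:
+
+```python
+from ultralytics import YOLO
+
+model = YOLO("yolov8n-cls.pt")
+model.train(data="mnist160", epochs=1, auto_augment="augmix", erasing=0.2)
+```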
## Logging, checkpoints, plotting and file management
diff --git a/tests/test_python.py b/tests/test_python.py
index 710974d42..8032602ef 100644
--- a/tests/test_python.py
+++ b/tests/test_python.py
@@ -505,6 +505,48 @@ def test_hub():
smart_request('GET', 'http://github.com', progress=True)
+@pytest.fixture
+def image():
+ return cv2.imread(str(SOURCE))
+
+
+@pytest.mark.parametrize(
+ 'auto_augment, erasing, force_color_jitter',
+ [
+ (None, 0.0, False),
+ ('randaugment', 0.5, True),
+ ('augmix', 0.2, False),
+ ('autoaugment', 0.0, True), ],
+)
+def test_classify_transforms_train(image, auto_augment, erasing, force_color_jitter):
+ import torchvision.transforms as T
+
+ from ultralytics.data.augment import classify_augmentations
+
+ transform = classify_augmentations(
+ size=224,
+ mean=(0.5, 0.5, 0.5),
+ std=(0.5, 0.5, 0.5),
+ scale=(0.08, 1.0),
+ ratio=(3.0 / 4.0, 4.0 / 3.0),
+ hflip=0.5,
+ vflip=0.5,
+ auto_augment=auto_augment,
+ hsv_h=0.015,
+ hsv_s=0.4,
+ hsv_v=0.4,
+ force_color_jitter=force_color_jitter,
+ erasing=erasing,
+ interpolation=T.InterpolationMode.BILINEAR,
+ )
+
+ transformed_image = transform(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
+
+ assert transformed_image.shape == (3, 224, 224)
+ assert torch.is_tensor(transformed_image)
+ assert transformed_image.dtype == torch.float32
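+
+ # Run this test in isolation with: pytest tests/test_python.py -k classify_transforms_train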
+
+
@pytest.mark.slow
@pytest.mark.skipif(not ONLINE, reason='environment is offline')
def test_model_tune():
diff --git a/ultralytics/__init__.py b/ultralytics/__init__.py
index 7ad3a039e..ad5d2484b 100644
--- a/ultralytics/__init__.py
+++ b/ultralytics/__init__.py
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
-__version__ = '8.0.232'
+__version__ = '8.0.233'
from ultralytics.models import RTDETR, SAM, YOLO
from ultralytics.models.fastsam import FastSAM
diff --git a/ultralytics/cfg/default.yaml b/ultralytics/cfg/default.yaml
index f6edad234..ead9735e6 100644
--- a/ultralytics/cfg/default.yaml
+++ b/ultralytics/cfg/default.yaml
@@ -113,6 +113,9 @@ fliplr: 0.5 # (float) image flip left-right (probability)
mosaic: 1.0 # (float) image mosaic (probability)
mixup: 0.0 # (float) image mixup (probability)
copy_paste: 0.0 # (float) segment copy-paste (probability)
+auto_augment: randaugment # (str) auto augmentation policy for classification (randaugment, autoaugment, augmix)
+erasing: 0.4 # (float) probability of random erasing during classification training (0-1)
+crop_fraction: 1.0 # (float) image crop fraction for classification evaluation/inference (0-1)
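+# Sketch (assuming standard override plumbing; mnist160 is a small stand-in dataset):
+# these keys can be set per-run, e.g. `yolo classify train data=mnist160 auto_augment=augmix erasing=0.2`
+# from the CLI, or `model.val(data='mnist160', crop_fraction=0.9)` from Python.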
# Custom config.yaml ---------------------------------------------------------------------------------------------------
cfg: # (str, optional) for overriding defaults.yaml
diff --git a/ultralytics/data/augment.py b/ultralytics/data/augment.py
index cd6a5465e..97eb5fdba 100644
--- a/ultralytics/data/augment.py
+++ b/ultralytics/data/augment.py
@@ -14,9 +14,14 @@ from ultralytics.utils.checks import check_version
from ultralytics.utils.instance import Instances
from ultralytics.utils.metrics import bbox_ioa
from ultralytics.utils.ops import segment2box
+from ultralytics.utils.torch_utils import TORCHVISION_0_10, TORCHVISION_0_11, TORCHVISION_0_13
from .utils import polygons2masks, polygons2masks_overlap
+DEFAULT_MEAN = (0.0, 0.0, 0.0)
+DEFAULT_STD = (1.0, 1.0, 1.0)
+DEFAULT_CROP_FRACTION = 1.0
+
# TODO: we might need a BaseTransform to make all these augments be compatible with both classification and semantic
class BaseTransform:
@@ -982,65 +987,144 @@ def v8_transforms(dataset, imgsz, hyp, stretch=False):
# Classification augmentations -----------------------------------------------------------------------------------------
-def classify_transforms(size=224, rect=False, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)): # IMAGENET_MEAN, IMAGENET_STD
- """Transforms to apply if albumentations not installed."""
+def classify_transforms(
+ size=224,
+ mean=DEFAULT_MEAN,
+ std=DEFAULT_STD,
+ interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+ crop_fraction: float = DEFAULT_CROP_FRACTION,
+):
+ """
+ Classification transforms for evaluation/inference. Inspired by timm/data/transforms_factory.py.
+
+ Args:
+ size (int): image size
+ mean (tuple): mean values of RGB channels
+ std (tuple): std values of RGB channels
+ interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
+ crop_fraction (float): fraction of image to crop. default is 1.0.
+
+ Returns:
+ T.Compose: torchvision transforms
+ """
+
+ if isinstance(size, (tuple, list)):
+ assert len(size) == 2
+ scale_size = tuple([math.floor(x / crop_fraction) for x in size])
+ else:
+ scale_size = math.floor(size / crop_fraction)
+ scale_size = (scale_size, scale_size)
+
+ # Aspect ratio is preserved: the crop is centered within the image and no borders are added, so edge content may be lost
+ if scale_size[0] == scale_size[1]:
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
+ tfl = [T.Resize(scale_size[0], interpolation=interpolation)]
+ else:
+ # resize to the exact (h, w) scale size for a non-square target
+ tfl = [T.Resize(scale_size)]
+ tfl += [T.CenterCrop(size)]
+
+ tfl += [T.ToTensor(), T.Normalize(
+ mean=torch.tensor(mean),
+ std=torch.tensor(std),
+ )]
+
+ return T.Compose(tfl)
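+
+# Worked example (sketch): size=224 with crop_fraction=0.875 reproduces the
+# classic ImageNet eval recipe: scale_size = floor(224 / 0.875) = 256, so the
+# shortest edge is resized to 256 before the 224x224 center crop, e.g.
+#   tfm = classify_transforms(size=224, crop_fraction=0.875)
+#   x = tfm(pil_image)  # torch.FloatTensor of shape (3, 224, 224)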
+
+
+# Classification augmentations train ---------------------------------------------------------------------------------------
+def classify_augmentations(
+ size=224,
+ mean=DEFAULT_MEAN,
+ std=DEFAULT_STD,
+ scale=None,
+ ratio=None,
+ hflip=0.5,
+ vflip=0.0,
+ auto_augment=None,
+ hsv_h=0.015, # image HSV-Hue augmentation (fraction)
+ hsv_s=0.4, # image HSV-Saturation augmentation (fraction)
+ hsv_v=0.4, # image HSV-Value augmentation (fraction)
+ force_color_jitter=False,
+ erasing=0.,
+ interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR,
+):
+ """
+ Classification transforms with augmentation for training. Inspired by timm/data/transforms_factory.py.
+
+ Args:
+ size (int): image size
+ scale (tuple): scale range of the image. default is (0.08, 1.0)
+ ratio (tuple): aspect ratio range of the image. default is (3./4., 4./3.)
+ mean (tuple): mean values of RGB channels
+ std (tuple): std values of RGB channels
+ hflip (float): probability of horizontal flip
+ vflip (float): probability of vertical flip
+ auto_augment (str): auto augmentation policy. can be 'randaugment', 'augmix', 'autoaugment' or None.
+ hsv_h (float): image HSV-Hue augmentation (fraction)
+ hsv_s (float): image HSV-Saturation augmentation (fraction)
+ hsv_v (float): image HSV-Value augmentation (fraction)
+ force_color_jitter (bool): force to apply color jitter even if auto augment is enabled
+ erasing (float): probability of random erasing
+ interpolation (T.InterpolationMode): interpolation mode. default is T.InterpolationMode.BILINEAR.
+
+ Returns:
+ T.Compose: torchvision transforms
+ """
+ # Transforms to apply if albumentations not installed
if not isinstance(size, int):
- raise TypeError(f'classify_transforms() size {size} must be integer, not (list, tuple)')
+ raise TypeError(f'classify_augmentations() size {size} must be an integer, not a list or tuple')
- transforms = [ClassifyLetterBox(size, auto=True) if rect else CenterCrop(size), ToTensor()]
- if any(mean) or any(std):
- transforms.append(T.Normalize(mean, std, inplace=True))
- return T.Compose(transforms)
-
-
-def hsv2colorjitter(h, s, v):
- """Map HSV (hue, saturation, value) jitter into ColorJitter values (brightness, contrast, saturation, hue)"""
- return v, v, s, h
-
-
-def classify_albumentations(
- augment=True,
- size=224,
- scale=(0.08, 1.0),
- hflip=0.5,
- vflip=0.0,
- hsv_h=0.015, # image HSV-Hue augmentation (fraction)
- hsv_s=0.7, # image HSV-Saturation augmentation (fraction)
- hsv_v=0.4, # image HSV-Value augmentation (fraction)
- mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
- std=(1.0, 1.0, 1.0), # IMAGENET_STD
- auto_aug=False,
-):
- """YOLOv8 classification Albumentations (optional, only used if package is installed)."""
- prefix = colorstr('albumentations: ')
- try:
- import albumentations as A
- from albumentations.pytorch import ToTensorV2
-
- check_version(A.__version__, '1.0.3', hard=True) # version requirement
- if augment: # Resize and crop
- T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
- if auto_aug:
- # TODO: implement AugMix, AutoAug & RandAug in albumentations
- LOGGER.info(f'{prefix}auto augmentations are currently not supported')
+ scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
+ ratio = tuple(ratio or (3. / 4., 4. / 3.)) # default imagenet ratio range
+ primary_tfl = [T.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)]
+ if hflip > 0.:
+ primary_tfl += [T.RandomHorizontalFlip(p=hflip)]
+ if vflip > 0.:
+ primary_tfl += [T.RandomVerticalFlip(p=vflip)]
+
+ secondary_tfl = []
+ disable_color_jitter = False
+ if auto_augment:
+ assert isinstance(auto_augment, str)
+ # Color jitter is typically disabled when auto-augment is active; force_color_jitter
+ # allows re-enabling it without breaking old hyperparameter configs
+ disable_color_jitter = not force_color_jitter
+
+ if auto_augment == 'randaugment':
+ if TORCHVISION_0_11:
+ secondary_tfl += [T.RandAugment(interpolation=interpolation)]
else:
- if hflip > 0:
- T += [A.HorizontalFlip(p=hflip)]
- if vflip > 0:
- T += [A.VerticalFlip(p=vflip)]
- if any((hsv_h, hsv_s, hsv_v)):
- T += [A.ColorJitter(*hsv2colorjitter(hsv_h, hsv_s, hsv_v))] # brightness, contrast, saturation, hue
- else: # Use fixed crop for eval set (reproducibility)
- T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
- T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
- LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
- return A.Compose(T)
-
- except ImportError: # package not installed, skip
- pass
- except Exception as e:
- LOGGER.info(f'{prefix}{e}')
+ LOGGER.warning('"auto_augment=randaugment" requires torchvision >= 0.11.0. Disabling it.')
+
+ elif auto_augment == 'augmix':
+ if TORCHVISION_0_13:
+ secondary_tfl += [T.AugMix(interpolation=interpolation)]
+ else:
+ LOGGER.warning('"auto_augment=augmix" requires torchvision >= 0.13.0. Disabling it.')
+
+ elif auto_augment == 'autoaugment':
+ if TORCHVISION_0_10:
+ secondary_tfl += [T.AutoAugment(interpolation=interpolation)]
+ else:
+ LOGGER.warning('"auto_augment=autoaugment" requires torchvision >= 0.10.0. Disabling it.')
+
+ else:
+ raise ValueError(f'Invalid auto_augment policy: {auto_augment}. Should be one of "randaugment", '
+ f'"augmix", "autoaugment" or None')
+
+ if not disable_color_jitter:
+ secondary_tfl += [T.ColorJitter(brightness=hsv_v, contrast=hsv_v, saturation=hsv_s, hue=hsv_h)]
+
+ final_tfl = [
+ T.ToTensor(),
+ T.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)),
+ T.RandomErasing(p=erasing, inplace=True)]
+
+ return T.Compose(primary_tfl + secondary_tfl + final_tfl)
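+
+
+# Usage sketch (assumption: hyperparameter values mirror default.yaml):
+#   tfm = classify_augmentations(size=224, auto_augment='randaugment', erasing=0.4)
+#   x = tfm(pil_image)  # augmented torch.FloatTensor of shape (3, 224, 224)
+
+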
+# NOTE: keep this class for backward compatibility
class ClassifyLetterBox:
"""
YOLOv8 LetterBox class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
@@ -1091,6 +1175,7 @@ class ClassifyLetterBox:
return im_out
+# NOTE: keep this class for backward compatibility
class CenterCrop:
"""YOLOv8 CenterCrop class for image preprocessing, designed to be part of a transformation pipeline, e.g.,
T.Compose([CenterCrop(size), ToTensor()]).
@@ -1117,6 +1202,7 @@ class CenterCrop:
return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
+# NOTE: keep this class for backward compatibility
class ToTensor:
"""YOLOv8 ToTensor class for image preprocessing, i.e., T.Compose([LetterBox(size), ToTensor()])."""
diff --git a/ultralytics/data/dataset.py b/ultralytics/data/dataset.py
index 068311efc..538a1ff2a 100644
--- a/ultralytics/data/dataset.py
+++ b/ultralytics/data/dataset.py
@@ -8,10 +8,11 @@ import cv2
import numpy as np
import torch
import torchvision
+from PIL import Image
from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM, colorstr, is_dir_writeable
-from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
+from .augment import Compose, Format, Instances, LetterBox, classify_augmentations, classify_transforms, v8_transforms
from .base import BaseDataset
from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
@@ -225,19 +226,17 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
self.cache_disk = cache == 'disk'
self.samples = self.verify_images() # filter out bad images
self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
- self.torch_transforms = classify_transforms(args.imgsz, rect=args.rect)
- self.album_transforms = classify_albumentations(
- augment=augment,
- size=args.imgsz,
- scale=(1.0 - args.scale, 1.0), # (0.08, 1.0)
- hflip=args.fliplr,
- vflip=args.flipud,
- hsv_h=args.hsv_h, # HSV-Hue augmentation (fraction)
- hsv_s=args.hsv_s, # HSV-Saturation augmentation (fraction)
- hsv_v=args.hsv_v, # HSV-Value augmentation (fraction)
- mean=(0.0, 0.0, 0.0), # IMAGENET_MEAN
- std=(1.0, 1.0, 1.0), # IMAGENET_STD
- auto_aug=False) if augment else None
+ scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)
+ self.torch_transforms = classify_augmentations(size=args.imgsz,
+ scale=scale,
+ hflip=args.fliplr,
+ vflip=args.flipud,
+ erasing=args.erasing,
+ auto_augment=args.auto_augment,
+ hsv_h=args.hsv_h,
+ hsv_s=args.hsv_s,
+ hsv_v=args.hsv_v) if augment else classify_transforms(
+ size=args.imgsz, crop_fraction=args.crop_fraction)
def __getitem__(self, i):
"""Returns subset of data and targets corresponding to given indices."""
@@ -250,10 +249,9 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
im = np.load(fn)
else: # read image
im = cv2.imread(f) # BGR
- if self.album_transforms:
- sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))['image']
- else:
- sample = self.torch_transforms(im)
+ # Convert NumPy array to PIL image
+ im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+ sample = self.torch_transforms(im)
return {'img': sample, 'cls': j}
def __len__(self) -> int:
diff --git a/ultralytics/engine/predictor.py b/ultralytics/engine/predictor.py
index cda599441..f78de2fa8 100644
--- a/ultralytics/engine/predictor.py
+++ b/ultralytics/engine/predictor.py
@@ -210,8 +210,9 @@ class BasePredictor:
def setup_source(self, source):
"""Sets up source and inference mode."""
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
- self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
- self.imgsz[0])) if self.args.task == 'classify' else None
+ self.transforms = getattr(
+ self.model.model, 'transforms', classify_transforms(
+ self.imgsz[0], crop_fraction=self.args.crop_fraction)) if self.args.task == 'classify' else None
self.dataset = load_inference_source(source=source,
imgsz=self.imgsz,
vid_stride=self.args.vid_stride,
diff --git a/ultralytics/models/yolo/classify/predict.py b/ultralytics/models/yolo/classify/predict.py
index ca463b67f..9047d8df3 100644
--- a/ultralytics/models/yolo/classify/predict.py
+++ b/ultralytics/models/yolo/classify/predict.py
@@ -1,6 +1,8 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
+import cv2
import torch
+from PIL import Image
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
@@ -29,11 +31,18 @@ class ClassificationPredictor(BasePredictor):
"""Initializes ClassificationPredictor setting the task to 'classify'."""
super().__init__(cfg, overrides, _callbacks)
self.args.task = 'classify'
+ self._legacy_transform_name = 'ultralytics.yolo.data.augment.ToTensor'
def preprocess(self, img):
"""Converts input image to model-compatible data type."""
if not isinstance(img, torch.Tensor):
- img = torch.stack([self.transforms(im) for im in img], dim=0)
+ is_legacy_transform = any(self._legacy_transform_name in str(transform)
+ for transform in self.transforms.transforms)
+ if is_legacy_transform: # to handle legacy transforms
+ img = torch.stack([self.transforms(im) for im in img], dim=0)
+ else:
+ img = torch.stack([self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img],
+ dim=0)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # uint8 to fp16/32
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 3a17a9aca..e8051571e 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -15,6 +15,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
+import torchvision
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
from ultralytics.utils.checks import check_version
@@ -26,6 +27,9 @@ except ImportError:
TORCH_1_9 = check_version(torch.__version__, '1.9.0')
TORCH_2_0 = check_version(torch.__version__, '2.0.0')
+TORCHVISION_0_10 = check_version(torchvision.__version__, '0.10.0')
+TORCHVISION_0_11 = check_version(torchvision.__version__, '0.11.0')
+TORCHVISION_0_13 = check_version(torchvision.__version__, '0.13.0')
@contextmanager