`ultralytics 8.1.17` fix `ClassificationDataset` caching (#8358)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
pull/8392/head v8.1.17
Glenn Jocher 1 year ago committed by GitHub
parent 604b9d0794
commit 2945cfc6ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      mkdocs.yml
  2. 2
      ultralytics/__init__.py
  3. 37
      ultralytics/data/dataset.py

@ -197,7 +197,7 @@ nav:
- Python: usage/python.md
- Callbacks: usage/callbacks.md
- Configuration: usage/cfg.md
- Simple-Utilities: usage/simple-utilities.md
- Simple Utilities: usage/simple-utilities.md
- Advanced Customization: usage/engine.md
- Modes:
- modes/index.md

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.1.16"
__version__ = "8.1.17"
from ultralytics.data.explorer.explorer import Explorer
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld

@ -226,35 +226,42 @@ class YOLODataset(BaseDataset):
# Classification dataloaders -------------------------------------------------------------------------------------------
class ClassificationDataset(torchvision.datasets.ImageFolder):
"""
YOLO Classification Dataset.
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep
learning models, with optional image transformations and caching mechanisms to speed up training.
Args:
root (str): Dataset path.
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process
to ensure data integrity and consistency.
Attributes:
cache_ram (bool): True if images should be cached in RAM, False otherwise.
cache_disk (bool): True if images should be cached on disk, False otherwise.
samples (list): List of samples containing file, index, npy, and im.
torch_transforms (callable): torchvision transforms applied to the dataset.
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
cache_ram (bool): Indicates if caching in RAM is enabled.
cache_disk (bool): Indicates if caching on disk is enabled.
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache
file (if caching on disk), and optionally the loaded image array (if caching in RAM).
torch_transforms (callable): PyTorch transforms to be applied to the images.
"""
def __init__(self, root, args, augment=False, cache=False, prefix=""):
def __init__(self, root, args, augment=False, prefix=""):
"""
Initialize YOLO object with root, image size, augmentations, and cache settings.
Args:
root (str): Dataset path.
args (Namespace): Argument parser containing dataset related settings.
augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False.
cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False.
root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training),
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`.
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False.
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and
debugging. Default is an empty string.
"""
super().__init__(root=root)
if augment and args.fraction < 1.0: # reduce training fraction
self.samples = self.samples[: round(len(self.samples) * args.fraction)]
self.prefix = colorstr(f"{prefix}: ") if prefix else ""
self.cache_ram = cache is True or cache == "ram"
self.cache_disk = cache == "disk"
self.cache_ram = args.cache is True or args.cache == "ram" # cache images into RAM
self.cache_disk = args.cache == "disk" # cache images on hard drive as uncompressed *.npy files
self.samples = self.verify_images() # filter out bad images
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0)

Loading…
Cancel
Save