|
|
|
@ -226,35 +226,42 @@ class YOLODataset(BaseDataset): |
|
|
|
|
# Classification dataloaders ------------------------------------------------------------------------------------------- |
|
|
|
|
class ClassificationDataset(torchvision.datasets.ImageFolder): |
|
|
|
|
""" |
|
|
|
|
YOLO Classification Dataset. |
|
|
|
|
Extends torchvision ImageFolder to support YOLO classification tasks, offering functionalities like image |
|
|
|
|
augmentation, caching, and verification. It's designed to efficiently handle large datasets for training deep |
|
|
|
|
learning models, with optional image transformations and caching mechanisms to speed up training. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
root (str): Dataset path. |
|
|
|
|
This class allows for augmentations using both torchvision and Albumentations libraries, and supports caching images |
|
|
|
|
in RAM or on disk to reduce IO overhead during training. Additionally, it implements a robust verification process |
|
|
|
|
to ensure data integrity and consistency. |
|
|
|
|
|
|
|
|
|
Attributes: |
|
|
|
|
cache_ram (bool): True if images should be cached in RAM, False otherwise. |
|
|
|
|
cache_disk (bool): True if images should be cached on disk, False otherwise. |
|
|
|
|
samples (list): List of samples containing file, index, npy, and im. |
|
|
|
|
torch_transforms (callable): torchvision transforms applied to the dataset. |
|
|
|
|
album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True. |
|
|
|
|
cache_ram (bool): Indicates if caching in RAM is enabled. |
|
|
|
|
cache_disk (bool): Indicates if caching on disk is enabled. |
|
|
|
|
samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy cache |
|
|
|
|
file (if caching on disk), and optionally the loaded image array (if caching in RAM). |
|
|
|
|
torch_transforms (callable): PyTorch transforms to be applied to the images. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def __init__(self, root, args, augment=False, cache=False, prefix=""): |
|
|
|
|
def __init__(self, root, args, augment=False, prefix=""): |
|
|
|
|
""" |
|
|
|
|
Initialize YOLO object with root, image size, augmentations, and cache settings. |
|
|
|
|
|
|
|
|
|
Args: |
|
|
|
|
root (str): Dataset path. |
|
|
|
|
args (Namespace): Argument parser containing dataset related settings. |
|
|
|
|
augment (bool, optional): True if dataset should be augmented, False otherwise. Defaults to False. |
|
|
|
|
cache (bool | str | optional): Cache setting, can be True, False, 'ram' or 'disk'. Defaults to False. |
|
|
|
|
root (str): Path to the dataset directory where images are stored in a class-specific folder structure. |
|
|
|
|
args (Namespace): Configuration containing dataset-related settings such as image size, augmentation |
|
|
|
|
parameters, and cache settings. It includes attributes like `imgsz` (image size), `fraction` (fraction |
|
|
|
|
of data to use), `scale`, `fliplr`, `flipud`, `cache` (disk or RAM caching for faster training), |
|
|
|
|
`auto_augment`, `hsv_h`, `hsv_s`, `hsv_v`, and `crop_fraction`. |
|
|
|
|
augment (bool, optional): Whether to apply augmentations to the dataset. Default is False. |
|
|
|
|
prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification and |
|
|
|
|
debugging. Default is an empty string. |
|
|
|
|
""" |
|
|
|
|
super().__init__(root=root) |
|
|
|
|
if augment and args.fraction < 1.0: # reduce training fraction |
|
|
|
|
self.samples = self.samples[: round(len(self.samples) * args.fraction)] |
|
|
|
|
self.prefix = colorstr(f"{prefix}: ") if prefix else "" |
|
|
|
|
self.cache_ram = cache is True or cache == "ram" |
|
|
|
|
self.cache_disk = cache == "disk" |
|
|
|
|
self.cache_ram = args.cache is True or args.cache == "ram" # cache images into RAM |
|
|
|
|
self.cache_disk = args.cache == "disk" # cache images on hard drive as uncompressed *.npy files |
|
|
|
|
self.samples = self.verify_images() # filter out bad images |
|
|
|
|
self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples] # file, index, npy, im |
|
|
|
|
scale = (1.0 - args.scale, 1.0) # (0.08, 1.0) |
|
|
|
|