@@ -1,5 +1,5 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
-
+import contextlib
 from itertools import repeat
 from multiprocessing.pool import ThreadPool
 from pathlib import Path
@@ -10,11 +10,14 @@ import torch
 import torchvision
 from tqdm import tqdm
 
-from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, is_dir_writeable
+from ultralytics.utils import LOCAL_RANK, NUM_THREADS, TQDM_BAR_FORMAT, colorstr, is_dir_writeable
 
 from .augment import Compose, Format, Instances, LetterBox, classify_albumentations, classify_transforms, v8_transforms
 from .base import BaseDataset
-from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image_label
+from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label
 
+# Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
+DATASET_CACHE_VERSION = '1.0.2'
+
 
 class YOLODataset(BaseDataset):
@@ -29,7 +32,6 @@ class YOLODataset(BaseDataset):
     Returns:
         (torch.utils.data.Dataset): A PyTorch dataset object that can be used for training an object detection model.
     """
-    cache_version = '1.0.2'  # dataset labels *.cache version, >= 1.0.0 for YOLOv8
 
     def __init__(self, *args, data=None, use_segments=False, use_keypoints=False, **kwargs):
         self.use_segments = use_segments
@@ -87,15 +89,7 @@ class YOLODataset(BaseDataset):
         x['hash'] = get_hash(self.label_files + self.im_files)
         x['results'] = nf, nm, ne, nc, len(self.im_files)
         x['msgs'] = msgs  # warnings
-        x['version'] = self.cache_version  # cache version
-        if is_dir_writeable(path.parent):
-            if path.exists():
-                path.unlink()  # remove *.cache file if exists
-            np.save(str(path), x)  # save cache for next time
-            path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
-            LOGGER.info(f'{self.prefix}New cache created: {path}')
-        else:
-            LOGGER.warning(f'{self.prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+        save_dataset_cache_file(self.prefix, path, x)
         return x
 
     def get_labels(self):
@@ -103,11 +97,8 @@ class YOLODataset(BaseDataset):
         self.label_files = img2label_paths(self.im_files)
         cache_path = Path(self.label_files[0]).parent.with_suffix('.cache')
         try:
-            import gc
-            gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
-            cache, exists = np.load(str(cache_path), allow_pickle=True).item(), True  # load dict
-            gc.enable()
-            assert cache['version'] == self.cache_version  # matches current version
+            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
+            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
             assert cache['hash'] == get_hash(self.label_files + self.im_files)  # identical hash
         except (FileNotFoundError, AssertionError, AttributeError):
             cache, exists = self.cache_labels(cache_path), False  # run cache ops
@@ -116,7 +107,7 @@ class YOLODataset(BaseDataset):
         nf, nm, ne, nc, n = cache.pop('results')  # found, missing, empty, corrupt, total
         if exists and LOCAL_RANK in (-1, 0):
             d = f'Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt'
-            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display cache results
+            tqdm(None, desc=self.prefix + d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)  # display results
             if cache['msgs']:
                 LOGGER.info('\n'.join(cache['msgs']))  # display warnings
         if nf == 0:  # number of labels found
@@ -216,7 +207,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         album_transforms (callable, optional): Albumentations transforms applied to the dataset if augment is True.
     """
 
-    def __init__(self, root, args, augment=False, cache=False):
+    def __init__(self, root, args, augment=False, cache=False, prefix=''):
         """
         Initialize YOLO object with root, image size, augmentations, and cache settings.
 
@@ -229,8 +220,10 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         super().__init__(root=root)
         if augment and args.fraction < 1.0:  # reduce training fraction
             self.samples = self.samples[:round(len(self.samples) * args.fraction)]
+        self.prefix = colorstr(f'{prefix}: ') if prefix else ''
         self.cache_ram = cache is True or cache == 'ram'
         self.cache_disk = cache == 'disk'
+        self.samples = self.verify_images()  # filter out bad images
         self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples]  # file, index, npy, im
         self.torch_transforms = classify_transforms(args.imgsz)
         self.album_transforms = classify_albumentations(
@@ -266,6 +259,67 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
     def __len__(self) -> int:
         return len(self.samples)
 
+    def verify_images(self):
+        """Verify all images in dataset."""
+        desc = f'{self.prefix}Scanning {self.root}...'
+        path = Path(self.root).with_suffix('.cache')  # *.cache file path
+
+        with contextlib.suppress(FileNotFoundError, AssertionError, AttributeError):
+            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
+            assert cache['version'] == DATASET_CACHE_VERSION  # matches current version
+            assert cache['hash'] == get_hash([x[0] for x in self.samples])  # identical hash
+            nf, nc, n, samples = cache.pop('results')  # found, corrupt, total, samples
+            if LOCAL_RANK in (-1, 0):
+                d = f'{desc} {nf} images, {nc} corrupt'
+                tqdm(None, desc=d, total=n, initial=n, bar_format=TQDM_BAR_FORMAT)
+                if cache['msgs']:
+                    LOGGER.info('\n'.join(cache['msgs']))  # display warnings
+            return samples
+
+        # Run scan if *.cache retrieval failed
+        nf, nc, msgs, samples, x = 0, 0, [], [], {}
+        with ThreadPool(NUM_THREADS) as pool:
+            results = pool.imap(func=verify_image, iterable=zip([x[0] for x in self.samples], repeat(self.prefix)))
+            pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT)
+            for im_file, nf_f, nc_f, msg in pbar:
+                if nf_f:
+                    samples.append((im_file, nf))
+                if msg:
+                    msgs.append(msg)
+                nf += nf_f
+                nc += nc_f
+                pbar.desc = f'{desc} {nf} images, {nc} corrupt'
+            pbar.close()
+        if msgs:
+            LOGGER.info('\n'.join(msgs))
+        x['hash'] = get_hash([x[0] for x in self.samples])
+        x['results'] = nf, nc, len(samples), samples
+        x['msgs'] = msgs  # warnings
+        save_dataset_cache_file(self.prefix, path, x)
+        return samples
+
+
+def load_dataset_cache_file(path):
+    """Load an Ultralytics *.cache dictionary from path."""
+    import gc
+    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
+    cache = np.load(str(path), allow_pickle=True).item()  # load dict
+    gc.enable()
+    return cache
+
+
+def save_dataset_cache_file(prefix, path, x):
+    """Save an Ultralytics dataset *.cache dictionary x to path."""
+    x['version'] = DATASET_CACHE_VERSION  # add cache version
+    if is_dir_writeable(path.parent):
+        if path.exists():
+            path.unlink()  # remove *.cache file if exists
+        np.save(str(path), x)  # save cache for next time
+        path.with_suffix('.cache.npy').rename(path)  # remove .npy suffix
+        LOGGER.info(f'{prefix}New cache created: {path}')
+    else:
+        LOGGER.warning(f'{prefix}WARNING ⚠️ Cache directory {path.parent} is not writeable, cache not saved.')
+
 
 # TODO: support semantic segmentation
 class SemanticDataset(BaseDataset):
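
For reference, the two cache helpers added above simply round-trip a Python dict through NumPy's pickle path, with a rename step because `np.save` always appends a `.npy` suffix. Below is a minimal standalone sketch of that round trip; the `save_cache`/`load_cache` names and the `demo.cache` path are illustrative stand-ins, not code from this diff.

```python
from pathlib import Path

import numpy as np

DATASET_CACHE_VERSION = '1.0.2'  # mirrors the module-level constant added above


def save_cache(path: Path, x: dict):
    """Sketch of save_dataset_cache_file(): stamp version, save, strip the .npy suffix."""
    x['version'] = DATASET_CACHE_VERSION
    np.save(str(path), x)  # writes 'demo.cache.npy' (np.save appends .npy)
    path.with_suffix('.cache.npy').rename(path)  # rename back to 'demo.cache'


def load_cache(path: Path) -> dict:
    """Sketch of load_dataset_cache_file(): np.load returns a 0-d object array."""
    return np.load(str(path), allow_pickle=True).item()  # .item() unwraps the dict


cache = {'hash': 'abc123', 'results': (10, 0, 10, []), 'msgs': []}
p = Path('demo.cache')  # hypothetical path
save_cache(p, cache)
assert load_cache(p)['version'] == DATASET_CACHE_VERSION
```

The `version` stamp and the file `hash` stored alongside the results are what let `get_labels()` and `verify_images()` silently fall back to a full rescan whenever the cache is stale or missing.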