`ultralytics 8.0.161` fix Classify dataset scanning bug (#4515)

pull/4517/head v8.0.161
Glenn Jocher 1 year ago committed by GitHub
parent 3c40e7a9fc
commit 67eeb0468d
4 changed files:

  1. ultralytics/__init__.py (2 changes)
  2. ultralytics/data/dataset.py (8 changes)
  3. ultralytics/data/utils.py (23 changes)
  4. ultralytics/utils/downloads.py (14 changes)

ultralytics/__init__.py
@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

-__version__ = '8.0.160'
+__version__ = '8.0.161'

 from ultralytics.models import RTDETR, SAM, YOLO
 from ultralytics.models.fastsam import FastSAM

ultralytics/data/dataset.py
@@ -17,7 +17,7 @@ from .base import BaseDataset
 from .utils import HELP_URL, LOGGER, get_hash, img2label_paths, verify_image, verify_image_label

 # Ultralytics dataset *.cache version, >= 1.0.0 for YOLOv8
-DATASET_CACHE_VERSION = '1.0.2'
+DATASET_CACHE_VERSION = '1.0.3'


 class YOLODataset(BaseDataset):
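
Bumping `DATASET_CACHE_VERSION` forces a rescan on upgrade: *.cache files written by the buggy scan below carry the old version string, so they are discarded rather than reloaded with bad class indices.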
@@ -279,11 +279,11 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
         # Run scan if *.cache retrieval failed
         nf, nc, msgs, samples, x = 0, 0, [], [], {}
         with ThreadPool(NUM_THREADS) as pool:
-            results = pool.imap(func=verify_image, iterable=zip([x[0] for x in self.samples], repeat(self.prefix)))
+            results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
             pbar = tqdm(results, desc=desc, total=len(self.samples), bar_format=TQDM_BAR_FORMAT)
-            for im_file, nf_f, nc_f, msg in pbar:
+            for sample, nf_f, nc_f, msg in pbar:
                 if nf_f:
-                    samples.append((im_file, nf))
+                    samples.append(sample)
                 if msg:
                     msgs.append(msg)
                 nf += nf_f
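
The bug: the scan loop passed only the image paths to `verify_image` and rebuilt each surviving sample as `(im_file, nf)`, pairing the image with the running found-counter instead of its class index. The fix threads the full `(im_file, cls)` tuple through `verify_image`, whose signature changes accordingly in `ultralytics/data/utils.py` below; see the sketch after that hunk.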

ultralytics/data/utils.py
@@ -59,7 +59,7 @@ def exif_size(img: Image.Image):

 def verify_image(args):
     """Verify one image."""
-    im_file, prefix = args
+    (im_file, cls), prefix = args
     # Number (found, corrupt), message
     nf, nc, msg = 0, 0, ''
     try:
@@ -79,7 +79,7 @@ def verify_image(args):
     except Exception as e:
         nc = 1
         msg = f'{prefix}WARNING ⚠ {im_file}: ignoring corrupt image/label: {e}'
-    return im_file, nf, nc, msg
+    return (im_file, cls), nf, nc, msg


 def verify_image_label(args):
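
For context, a minimal standalone sketch of the fixed data flow, with `verify_image`'s real format/EXIF checks stubbed out and hypothetical file paths:

```python
# Minimal sketch of the fixed scan flow. verify_image's real image checks are
# stubbed out; only the (im_file, cls) tuple threading is illustrated.
from itertools import repeat
from multiprocessing.pool import ThreadPool


def verify_image(args):
    """Pass the (im_file, cls) sample through, plus found/corrupt counts."""
    (im_file, cls), prefix = args
    nf, nc, msg = 1, 0, ''  # pretend every image verifies cleanly
    return (im_file, cls), nf, nc, msg


raw = [('images/cat/1.jpg', 0), ('images/dog/2.jpg', 1)]  # hypothetical samples
samples, nf = [], 0
with ThreadPool(2) as pool:
    for sample, nf_f, nc_f, msg in pool.imap(verify_image, zip(raw, repeat(''))):
        if nf_f:
            samples.append(sample)  # keeps the true class index, not the running count
        nf += nf_f

print(samples)  # [('images/cat/1.jpg', 0), ('images/dog/2.jpg', 1)]
```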
@@ -321,7 +321,7 @@ def check_cls_dataset(dataset: str, split=''):
     dataset = Path(dataset)
     data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
     if not data_dir.is_dir():
-        LOGGER.info(f'\nDataset not found ⚠, missing path {data_dir}, attempting download...')
+        LOGGER.warning(f'\nDataset not found ⚠, missing path {data_dir}, attempting download...')
         t = time.time()
         if str(dataset) == 'imagenet':
             subprocess.run(f"bash {ROOT / 'data/scripts/get_imagenet.sh'}", shell=True, check=True)
@@ -335,9 +335,9 @@ def check_cls_dataset(dataset: str, split=''):
         data_dir / 'validation').exists() else None  # data/test or data/val
     test_set = data_dir / 'test' if (data_dir / 'test').exists() else None  # data/val or data/test
     if split == 'val' and not val_set:
-        LOGGER.info("WARNING ⚠ Dataset 'split=val' not found, using 'split=test' instead.")
+        LOGGER.warning("WARNING ⚠ Dataset 'split=val' not found, using 'split=test' instead.")
     elif split == 'test' and not test_set:
-        LOGGER.info("WARNING ⚠ Dataset 'split=test' not found, using 'split=val' instead.")
+        LOGGER.warning("WARNING ⚠ Dataset 'split=test' not found, using 'split=val' instead.")
     nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()])  # number of classes
     names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()]  # class names list
@@ -345,13 +345,22 @@ def check_cls_dataset(dataset: str, split=''):

     # Print to console
     for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items():
+        prefix = f'{colorstr(k)}: {v}...'
         if v is None:
-            LOGGER.info(f'{colorstr(k)}: {v}')
+            LOGGER.info(prefix)
         else:
             files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS]
             nf = len(files)  # number of files
             nd = len({file.parent for file in files})  # number of directories
-            LOGGER.info(f'{colorstr(k)}: {v}... found {nf} images in {nd} classes ✅ ')  # keep trailing space
+            if nf == 0:
+                if k == 'train':
+                    raise FileNotFoundError(emojis(f"{dataset} '{k}:' no training images found ❌ "))
+                else:
+                    LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: WARNING ⚠ no images found')
+            elif nd != nc:
+                LOGGER.warning(f'{prefix} found {nf} images in {nd} classes: ERROR ❌ requires {nc} classes, not {nd}')
+            else:
+                LOGGER.info(f'{prefix} found {nf} images in {nd} classes ✅ ')

     return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names}
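
The console report is now tri-state instead of always printing a success line: a train split with no images raises, an empty val/test split warns, and a split whose directory count disagrees with `nc` is flagged. A standalone sketch of that logic (simplified; `summarize_split` is a hypothetical helper name and `IMG_FORMATS` here is an assumed subset):

```python
from pathlib import Path

IMG_FORMATS = {'bmp', 'jpeg', 'jpg', 'png', 'tif', 'tiff', 'webp'}  # assumed subset


def summarize_split(split_dir: Path, nc: int, is_train: bool) -> str:
    """Mirror the new check: error on an empty train split, warn on class-count mismatch."""
    files = [p for p in split_dir.rglob('*.*') if p.suffix[1:].lower() in IMG_FORMATS]
    nf = len(files)                      # number of image files
    nd = len({f.parent for f in files})  # number of class directories actually seen
    if nf == 0:
        if is_train:
            raise FileNotFoundError(f'no training images found in {split_dir}')
        return 'WARNING: no images found'
    if nd != nc:
        return f'ERROR: requires {nc} classes, not {nd}'
    return f'found {nf} images in {nd} classes'
```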

ultralytics/utils/downloads.py
@@ -39,16 +39,17 @@ def is_url(url, check=True):
     return False


-def delete_dsstore(path):
+def delete_dsstore(path, files_to_delete=('.DS_Store', '__MACOSX')):
     """
     Deletes all ".DS_store" files under a specified directory.

     Args:
         path (str, optional): The directory path where the ".DS_store" files should be deleted.
+        files_to_delete (tuple): The files to be deleted.

     Example:
         ```python
-        from ultralytics.data.utils import delete_dsstore
+        from ultralytics.utils.downloads import delete_dsstore

         delete_dsstore('path/to/dir')
         ```
@@ -58,10 +59,11 @@ def delete_dsstore(path):
     are hidden system files and can cause issues when transferring files between different operating systems.
     """
     # Delete Apple .DS_store files
-    files = list(Path(path).rglob('.DS_store'))
-    LOGGER.info(f'Deleting *.DS_store files: {files}')
-    for f in files:
-        f.unlink()
+    for file in files_to_delete:
+        matches = list(Path(path).rglob(file))
+        LOGGER.info(f'Deleting {file} files: {matches}')
+        for f in matches:
+            f.unlink()


 def zip_directory(directory, compress=True, exclude=('.DS_Store', '__MACOSX'), progress=True):
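
With the widened signature, `__MACOSX` archive residue is now swept by default alongside `.DS_Store`. Example usage (directory path hypothetical):

```python
from ultralytics.utils.downloads import delete_dsstore

delete_dsstore('path/to/dir')                                  # default: sweep .DS_Store and __MACOSX entries
delete_dsstore('path/to/dir', files_to_delete=('.DS_Store',))  # restrict to Finder metadata only
```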
