|
|
|
@ -304,7 +304,8 @@ def check_cls_dataset(dataset: str, split=''): |
|
|
|
|
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" |
|
|
|
|
LOGGER.info(s) |
|
|
|
|
train_set = data_dir / 'train' |
|
|
|
|
val_set = data_dir / 'val' if (data_dir / 'val').exists() else None # data/test or data/val |
|
|
|
|
val_set = data_dir / 'val' if (data_dir / 'val').exists() else data_dir / 'validation' if ( |
|
|
|
|
data_dir / 'validation').exists() else None # data/test or data/val |
|
|
|
|
test_set = data_dir / 'test' if (data_dir / 'test').exists() else None # data/val or data/test |
|
|
|
|
if split == 'val' and not val_set: |
|
|
|
|
LOGGER.info("WARNING ⚠️ Dataset 'split=val' not found, using 'split=test' instead.") |
|
|
|
@ -314,6 +315,17 @@ def check_cls_dataset(dataset: str, split=''): |
|
|
|
|
nc = len([x for x in (data_dir / 'train').glob('*') if x.is_dir()]) # number of classes |
|
|
|
|
names = [x.name for x in (data_dir / 'train').iterdir() if x.is_dir()] # class names list |
|
|
|
|
names = dict(enumerate(sorted(names))) |
|
|
|
|
|
|
|
|
|
# Print to console |
|
|
|
|
for k, v in {'train': train_set, 'val': val_set, 'test': test_set}.items(): |
|
|
|
|
if v is None: |
|
|
|
|
LOGGER.info(colorstr(k) + f': {v}') |
|
|
|
|
else: |
|
|
|
|
files = [path for path in v.rglob('*.*') if path.suffix[1:].lower() in IMG_FORMATS] |
|
|
|
|
nf = len(files) # number of files |
|
|
|
|
nd = len({file.parent for file in files}) # number of directories |
|
|
|
|
LOGGER.info(colorstr(k) + f': {v}... found {nf} images in {nd} classes ✅ ') # keep trailing space |
|
|
|
|
|
|
|
|
|
return {'train': train_set, 'val': val_set or test_set, 'test': test_set or val_set, 'nc': nc, 'names': names} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|