Fix Classify train from arbitrary CWD (#3692)

pull/3694/head
Glenn Jocher 1 year ago committed by GitHub
parent 15e9eac27b
commit 395cc47c53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 41
      ultralytics/yolo/data/utils.py

@ -268,28 +268,33 @@ def check_det_dataset(dataset, autodownload=True):
def check_cls_dataset(dataset: str, split=''):
"""
Check a classification dataset such as Imagenet.
Checks a classification dataset such as Imagenet.
This function takes a `dataset` name as input and returns a dictionary containing information about the dataset.
If the dataset is not found, it attempts to download the dataset from the internet and save it locally.
This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.
Args:
dataset (str): Name of the dataset.
split (str, optional): Dataset split, either 'val', 'test', or ''. Defaults to ''.
dataset (str): The name of the dataset.
split (str, optional): The split of the dataset. Either 'val', 'test', or ''. Defaults to ''.
Returns:
data (dict): A dictionary containing the following keys and values:
'train': Path object for the directory containing the training set of the dataset
'val': Path object for the directory containing the validation set of the dataset
'test': Path object for the directory containing the test set of the dataset
'nc': Number of classes in the dataset
'names': List of class names in the dataset
dict: A dictionary containing the following keys:
- 'train' (Path): The directory path containing the training set of the dataset.
- 'val' (Path): The directory path containing the validation set of the dataset.
- 'test' (Path): The directory path containing the test set of the dataset.
- 'nc' (int): The number of classes in the dataset.
- 'names' (dict): A dictionary of class names in the dataset.
Raises:
FileNotFoundError: If the specified dataset is not found and cannot be downloaded.
"""
data_dir = (DATASETS_DIR / dataset).resolve()
dataset = Path(dataset)
data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
if not data_dir.is_dir():
LOGGER.info(f'\nDataset not found ⚠, missing path {data_dir}, attempting download...')
t = time.time()
if dataset == 'imagenet':
if str(dataset) == 'imagenet':
subprocess.run(f"bash {ROOT / 'yolo/data/scripts/get_imagenet.sh'}", shell=True, check=True)
else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
@ -312,12 +317,12 @@ def check_cls_dataset(dataset: str, split=''):
class HUBDatasetStats():
"""
Class for generating HUB dataset JSON and `-hub` dataset directory
A class for generating HUB dataset JSON and `-hub` dataset directory.
Arguments
path: Path to data.yaml or data.zip (with data.yaml inside data.zip)
task: Dataset task. Options are 'detect', 'segment', 'pose', 'classify'.
autodownload: Attempt to download dataset if not found locally
Args:
path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip). Default is 'coco128.yaml'.
task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify'. Default is 'detect'.
autodownload (bool): Attempt to download dataset if not found locally. Default is False.
Usage
from ultralytics.yolo.data.utils import HUBDatasetStats

Loading…
Cancel
Save