You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

68 lines
2.7 KiB

import subprocess
import time
from pathlib import Path
import hydra
import torch
import torchvision
from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.files import WorkingDirectory
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
# BaseTrainer python usage
class ClassificationTrainer(BaseTrainer):
    """Image-classification trainer built on ``BaseTrainer``.

    Provides the dataset-resolution, dataloader, validator and loss hooks used
    by ``BaseTrainer`` for classification training.
    """

    def get_dataset(self, dataset):
        """Locate a classification dataset, downloading it if it is missing.

        Args:
            dataset: Dataset name; looked up as ``datasets/<dataset>`` relative to
                the current working directory.

        Returns:
            ``(train_set, test_set)`` paths: ``<root>/train`` and either
            ``<root>/test`` (when it exists) or ``<root>/val``.
        """
        # temporary solution. Replace with new ultralytics.yolo.ClassificationDataset module
        data = Path("datasets") / dataset
        # NOTE(review): presumably torch_distributed_zero_first lets rank 0 download
        # while other ranks wait (inferred from its name) — confirm in torch_utils.
        with torch_distributed_zero_first(LOCAL_RANK), WorkingDirectory(Path.cwd()):
            data_dir = data if data.is_dir() else (Path.cwd() / data)
            if not data_dir.is_dir():
                self.console.info(f'\nDataset not found ⚠, missing path {data_dir}, attempting download...')
                t = time.time()
                # Compare the bare dataset name: str(data) is "datasets/<name>",
                # so comparing it to 'imagenet' could never succeed.
                if str(dataset) == 'imagenet':
                    script = v8.ROOT / 'data/scripts/get_imagenet.sh'
                    # Argument list instead of a shell string keeps paths with spaces intact.
                    subprocess.run(['bash', str(script)], check=True)
                else:
                    url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
                    download(url, dir=data_dir.parent)
                # TODO: add colorstr
                s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {data_dir}\n"
                self.console.info(s)
            train_set = data_dir / "train"
            # data/test if present, otherwise data/val
            test_set = (data_dir / 'test') if (data_dir / 'test').exists() else (data_dir / 'val')
        return train_set, test_set

    def get_dataloader(self, dataset_path, batch_size=None, rank=0):
        """Build a classification dataloader for ``dataset_path``.

        Args:
            dataset_path: Directory passed to ``build_classification_dataloader``.
            batch_size: Batch size to use; falls back to ``self.args.batch_size``
                when ``None``.
            rank: Forwarded unchanged to ``build_classification_dataloader``.
        """
        if batch_size is None:
            batch_size = self.args.batch_size
        return build_classification_dataloader(path=dataset_path, batch_size=batch_size, rank=rank)

    def get_validator(self):
        """Return a ``ClassificationValidator`` over ``self.test_loader`` on ``self.device``."""
        return v8.classify.ClassificationValidator(self.test_loader, self.device, logger=self.console)

    def criterion(self, preds, targets):
        """Return the cross-entropy loss between ``preds`` and ``targets``.

        NOTE(review): presumably ``preds`` are (N, C) logits and ``targets`` are
        class indices — confirm against the model/dataloader.
        """
        return torch.nn.functional.cross_entropy(preds, targets)
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0])
def train(cfg):
    """Hydra entry point: apply model/data defaults, then run classification training.

    Args:
        cfg: Hydra config loaded from ``CONFIG_PATH_ABS``. A falsy ``cfg.model``
            becomes "resnet18" and a falsy ``cfg.data`` becomes "imagenette160".
    """
    cfg.model = cfg.model if cfg.model else "resnet18"
    # alternatively: yolo.ClassificationDataset("mnist")
    cfg.data = cfg.data if cfg.data else "imagenette160"
    ClassificationTrainer(cfg).train()
if __name__ == "__main__":
    # CLI usage:
    #     python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1
    #
    # TODO: direct CLI support, e.g. `yolov8 classify_train args.epochs 10`
    train()