|
|
|
@ -4,13 +4,13 @@ from pathlib import Path |
|
|
|
|
|
|
|
|
|
import hydra |
|
|
|
|
import torch |
|
|
|
|
import torchvision |
|
|
|
|
|
|
|
|
|
from ultralytics.yolo import v8 |
|
|
|
|
from ultralytics.yolo.data import build_classification_dataloader |
|
|
|
|
from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer |
|
|
|
|
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer |
|
|
|
|
from ultralytics.yolo.utils.downloads import download |
|
|
|
|
from ultralytics.yolo.utils.files import WorkingDirectory |
|
|
|
|
from ultralytics.yolo.utils.loggers import colorstr |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -30,8 +30,7 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
else: |
|
|
|
|
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' |
|
|
|
|
download(url, dir=data_dir.parent) |
|
|
|
|
# TODO: add colorstr |
|
|
|
|
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n" |
|
|
|
|
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n" |
|
|
|
|
self.console.info(s) |
|
|
|
|
train_set = data_dir / "train" |
|
|
|
|
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val |
|
|
|
@ -48,7 +47,7 @@ class ClassificationTrainer(BaseTrainer): |
|
|
|
|
return torch.nn.functional.cross_entropy(preds, targets) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0]) |
|
|
|
|
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.stem) |
|
|
|
|
def train(cfg): |
|
|
|
|
cfg.model = cfg.model or "resnet18" |
|
|
|
|
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist") |
|
|
|
@ -59,7 +58,7 @@ def train(cfg): |
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
""" |
|
|
|
|
CLI usage: |
|
|
|
|
python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1 |
|
|
|
|
python path/to/train.py epochs=10 project=PROJECT lr0=0.1 |
|
|
|
|
|
|
|
|
|
TODO: |
|
|
|
|
Direct cli support, i.e, yolov8 classify_train args.epochs 10 |
|
|
|
|