Classify training cleanup (#33)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/34/head
Glenn Jocher 2 years ago committed by GitHub
parent 2e9b18ce4e
commit 6fe8bead35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      ultralytics/yolo/engine/trainer.py
  2. 25
      ultralytics/yolo/utils/configs/default.yml
  3. 17
      ultralytics/yolo/utils/modeling/__init__.py
  4. 11
      ultralytics/yolo/v8/classify/train.py

@ -24,13 +24,12 @@ from ultralytics.yolo.utils import LOGGER, ROOT
from ultralytics.yolo.utils.files import increment_path, save_yaml from ultralytics.yolo.utils.files import increment_path, save_yaml
from ultralytics.yolo.utils.modeling import get_model from ultralytics.yolo.utils.modeling import get_model
CONFIG_PATH_ABS = ROOT / "yolo/utils/configs" DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yml"
DEFAULT_CONFIG = "defaults.yaml"
class BaseTrainer: class BaseTrainer:
def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG, overrides={}): def __init__(self, config=DEFAULT_CONFIG, overrides={}):
self.console = LOGGER self.console = LOGGER
self.args = self._get_config(config, overrides) self.args = self._get_config(config, overrides)
self.validator = None self.validator = None

@ -1,25 +1,27 @@
model: null # YOLO 🚀 by Ultralytics, GPL-3.0 license
data: null # Default training settings and hyperparameters for medium-augmentation COCO training
# Training options
# Train settings -------------------------------------------------------------------------------------------------------
model: null # i.e. yolov5s.pt
data: null # i.e. coco128.yaml
epochs: 300 epochs: 300
batch_size: 16 batch_size: 16
img_size: 640 img_size: 640
nosave: False nosave: False
cache: False # True/ram for ram, or disc cache: False # True/ram, disk or False
device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu device: '' # cuda device, i.e. 0 or 0,1,2,3 or cpu
workers: 8 workers: 8
project: "ultralytics-yolo" project: 'runs'
name: "exp" # TODO: make this informative, maybe exp{#number}_{datetime} ? name: 'exp'
exist_ok: False exist_ok: False
pretrained: False pretrained: False
optimizer: "Adam" # choices=['SGD', 'Adam', 'AdamW', 'RMSProp'] optimizer: 'SGD' # choices=['SGD', 'Adam', 'AdamW', 'RMSProp']
verbose: False verbose: False
seed: 0 seed: 0
local_rank: -1 local_rank: -1
#-----------------------------------#
# Hyper-parameters # Hyperparameters ------------------------------------------------------------------------------------------------------
lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3) lr0: 0.001 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf) lrf: 0.01 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1 momentum: 0.937 # SGD momentum/Adam beta1
@ -50,9 +52,8 @@ mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability) mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability) copy_paste: 0.0 # segment copy-paste (probability)
# Hydra configs ------------------------------------- # Hydra configs --------------------------------------------------------------------------------------------------------
# to disable hydra directory creation
hydra: hydra:
output_subdir: null output_subdir: null # disable hydra directory creation
run: run:
dir: . dir: .

@ -107,18 +107,17 @@ def parse_model(d, ch): # model_dict, input_channels(3)
return nn.Sequential(*layers), sorted(save) return nn.Sequential(*layers), sorted(save)
def get_model(model: str): def get_model(model='s.pt', pretrained=True):
# Load a YOLO model locally, from torchvision, or from Ultralytics assets
if model.endswith(".pt"): if model.endswith(".pt"):
model = model.split(".")[0] model = model.split(".")[0]
if Path(model + ".pt").is_file(): if Path(f"{model}.pt").is_file(): # local file
trained_model = torch.load(model + ".pt", map_location='cpu') return torch.load(f"{model}.pt", map_location='cpu')
elif model in torchvision.models.__dict__: # try torch hub classifier models elif model in torchvision.models.__dict__: # TorchVision models i.e. resnet50, efficientnet_b0
trained_model = torch.hub.load("pytorch/vision", model, pretrained=True) return torchvision.models.__dict__[model](weights='IMAGENET1K_V1' if pretrained else None)
else: else: # Ultralytics assets
model_ckpt = attempt_download(model + ".pt") # try ultralytics assets return torch.load(attempt_download(f"{model}.pt"), map_location='cpu')
trained_model = torch.load(model_ckpt, map_location='cpu')
return trained_model
def yaml_load(file='data.yaml'): def yaml_load(file='data.yaml'):

@ -4,13 +4,13 @@ from pathlib import Path
import hydra import hydra
import torch import torch
import torchvision
from ultralytics.yolo import v8 from ultralytics.yolo import v8
from ultralytics.yolo.data import build_classification_dataloader from ultralytics.yolo.data import build_classification_dataloader
from ultralytics.yolo.engine.trainer import CONFIG_PATH_ABS, DEFAULT_CONFIG, BaseTrainer from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
from ultralytics.yolo.utils.downloads import download from ultralytics.yolo.utils.downloads import download
from ultralytics.yolo.utils.files import WorkingDirectory from ultralytics.yolo.utils.files import WorkingDirectory
from ultralytics.yolo.utils.loggers import colorstr
from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first from ultralytics.yolo.utils.torch_utils import LOCAL_RANK, torch_distributed_zero_first
@ -30,8 +30,7 @@ class ClassificationTrainer(BaseTrainer):
else: else:
url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip' url = f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{dataset}.zip'
download(url, dir=data_dir.parent) download(url, dir=data_dir.parent)
# TODO: add colorstr s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n"
s = f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {'bold', data_dir}\n"
self.console.info(s) self.console.info(s)
train_set = data_dir / "train" train_set = data_dir / "train"
test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val test_set = data_dir / 'test' if (data_dir / 'test').exists() else data_dir / 'val' # data/test or data/val
@ -48,7 +47,7 @@ class ClassificationTrainer(BaseTrainer):
return torch.nn.functional.cross_entropy(preds, targets) return torch.nn.functional.cross_entropy(preds, targets)
@hydra.main(version_base=None, config_path=CONFIG_PATH_ABS, config_name=str(DEFAULT_CONFIG).split(".")[0]) @hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.stem)
def train(cfg): def train(cfg):
cfg.model = cfg.model or "resnet18" cfg.model = cfg.model or "resnet18"
cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist") cfg.data = cfg.data or "imagenette160" # or yolo.ClassificationDataset("mnist")
@ -59,7 +58,7 @@ def train(cfg):
if __name__ == "__main__": if __name__ == "__main__":
""" """
CLI usage: CLI usage:
python ../path/to/train.py args.epochs=10 args.project="name" hyps.lr0=0.1 python path/to/train.py epochs=10 project=PROJECT lr0=0.1
TODO: TODO:
Direct cli support, i.e, yolov8 classify_train args.epochs 10 Direct cli support, i.e, yolov8 classify_train args.epochs 10

Loading…
Cancel
Save