""" |
|
Simple training loop; Boilerplate that could apply to any arbitrary neural network, |
|
""" |
|
|
|
import os |
|
import time |
|
from collections import defaultdict |
|
from copy import deepcopy |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Dict, Union |
|
|
|
import numpy as np |
|
import torch |
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
import torch.nn as nn |
|
from omegaconf import DictConfig, OmegaConf |
|
from torch.cuda import amp |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
from torch.optim import lr_scheduler |
|
from tqdm import tqdm |
|
|
|
import ultralytics.yolo.utils as utils |
|
import ultralytics.yolo.utils.callbacks as callbacks |
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml |
|
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT |
|
from ultralytics.yolo.utils.checks import print_args |
|
from ultralytics.yolo.utils.files import increment_path, save_yaml |
|
from ultralytics.yolo.utils.modeling import get_model |
|
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds, one_cycle |
|
|
|
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" |
|
RANK = int(os.getenv('RANK', -1)) |
|
|
|
|
|
class BaseTrainer: |
|
|
|
def __init__(self, config=DEFAULT_CONFIG, overrides={}): |
|
self.args = self._get_config(config, overrides) |
|
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) |
|
|
|
self.console = LOGGER |
|
self.validator = None |
|
self.model = None |
|
self.callbacks = defaultdict(list) |
|
self.save_dir = increment_path(Path(self.args.project) / self.args.name, exist_ok=self.args.exist_ok) |
|
self.wdir = self.save_dir / 'weights' # weights dir |
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir |
|
self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt' # checkpoint paths |
|
print_args(dict(self.args)) |
|
|
|
# Save run settings |
|
save_yaml(self.save_dir / 'args.yaml', OmegaConf.to_container(self.args, resolve=True)) |
|
|
|
# device |
|
self.device = utils.torch_utils.select_device(self.args.device, self.args.batch_size) |
|
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu') |
|
|
|
# Model and Dataloaders. |
|
self.data = self.args.data |
|
if self.data.endswith(".yaml"): |
|
self.data = check_dataset_yaml(self.data) |
|
else: |
|
self.data = check_dataset(self.data) |
|
self.trainset, self.testset = self.get_dataset(self.data) |
|
if self.args.model: |
|
self.model = self.get_model(self.args.model) |
|
self.ema = None |
|
|
|
# Optimization utils init |
|
self.lf = None |
|
self.scheduler = None |
|
|
|
# epoch level metrics |
|
self.metrics = {} # handle metrics returned by validator |
|
self.best_fitness = None |
|
self.fitness = None |
|
self.loss = None |
|
|
|
for callback, func in callbacks.default_callbacks.items(): |
|
self.add_callback(callback, func) |
|
callbacks.add_integration_callbacks(self) |
|
|
|
def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): |
|
""" |
|
Accepts yaml file name or DictConfig containing experiment configuration. |
|
Returns training args namespace |
|
:param config: Optional file name or DictConfig object |
|
""" |
|
if isinstance(config, (str, Path)): |
|
config = OmegaConf.load(config) |
|
elif isinstance(config, Dict): |
|
config = OmegaConf.create(config) |
|
|
|
# override |
|
if isinstance(overrides, str): |
|
overrides = OmegaConf.load(overrides) |
|
elif isinstance(overrides, Dict): |
|
overrides = OmegaConf.create(overrides) |
|
|
|
return OmegaConf.merge(config, overrides) |
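
    # Example (sketch): values in `overrides` take precedence in OmegaConf.merge,
    # so `_get_config(DEFAULT_CONFIG, {"epochs": 3})` returns the default args
    # with `epochs` replaced by 3. The override key shown here is illustrative.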

    def add_callback(self, onevent: str, callback):
        """
        Appends the given callback.
        """
        self.callbacks[onevent].append(callback)

    def set_callback(self, onevent: str, callback):
        """
        Overrides the existing callbacks with the given callback.
        """
        self.callbacks[onevent] = [callback]

    def trigger_callbacks(self, onevent: str):
        for callback in self.callbacks.get(onevent, []):
            callback(self)
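
    # Example (sketch): registering a custom hook for one of the events fired in
    # `_do_train()` below; the function name here is illustrative.
    #
    #   def log_epoch(trainer):
    #       trainer.console.info(f"starting epoch {trainer.epoch}")
    #
    #   trainer.add_callback("on_epoch_start", log_epoch)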

    def train(self):
        world_size = torch.cuda.device_count()
        if world_size > 1:
            mp.spawn(self._do_train, args=(world_size,), nprocs=world_size, join=True)
        else:
            self._do_train()

    def _setup_ddp(self, rank, world_size):
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '9020'
        torch.cuda.set_device(rank)
        self.device = torch.device('cuda', rank)
        print(f"RANK - WORLD_SIZE - DEVICE: {rank} - {world_size} - {self.device}")

        dist.init_process_group("nccl" if dist.is_nccl_available() else "gloo", rank=rank, world_size=world_size)
        self.model = self.model.to(self.device)
        self.model = DDP(self.model, device_ids=[rank])
        self.args.batch_size = self.args.batch_size // world_size
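
        # Note: MASTER_ADDR/MASTER_PORT are hard-coded for single-node training,
        # and the global batch size is split evenly across the spawned processes.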

    def _setup_train(self, rank):
        """
        Builds dataloaders and optimizer on the correct rank process.
        """
        # Optimizer
        self.set_model_attributes()
        accumulate = max(round(self.args.nbs / self.args.batch_size), 1)  # accumulate loss before optimizing
        self.args.weight_decay *= self.args.batch_size * accumulate / self.args.nbs  # scale weight_decay
        self.optimizer = build_optimizer(model=self.model,
                                         name=self.args.optimizer,
                                         lr=self.args.lr0,
                                         momentum=self.args.momentum,
                                         decay=self.args.weight_decay)
        # Scheduler
        if self.args.cos_lr:
            self.lf = one_cycle(1, self.args.lrf, self.args.epochs)  # cosine 1->hyp['lrf']
        else:
            self.lf = lambda x: (1 - x / self.args.epochs) * (1.0 - self.args.lrf) + self.args.lrf  # linear
        self.scheduler = lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
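
        # Both schedules satisfy lf(0) == 1 and lf(epochs) == lrf, so the learning
        # rate decays from lr0 to lr0 * lrf over the course of training.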

        # dataloaders
        self.train_loader = self.get_dataloader(self.trainset, batch_size=self.args.batch_size, rank=rank)
        if rank in {0, -1}:
            print("Creating testloader rank:", rank)
            self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1)
            self.validator = self.get_validator()
            print("created testloader:", rank)
            self.ema = ModelEMA(self.model)

    def _do_train(self, rank=-1, world_size=1):
        if world_size > 1:
            self._setup_ddp(rank, world_size)
        else:
            self.model = self.model.to(self.device)

        self.trigger_callbacks("before_train")
        self._setup_train(rank)

        self.epoch = 0
        self.epoch_time = None
        self.epoch_time_start = time.time()
        self.train_time_start = time.time()
        nb = len(self.train_loader)  # number of batches
        nw = max(round(self.args.warmup_epochs * nb), 100)  # number of warmup iterations
        last_opt_step = -1
        for epoch in range(self.args.epochs):
            self.trigger_callbacks("on_epoch_start")
            self.model.train()
            pbar = enumerate(self.train_loader)
            if rank in {-1, 0}:
                pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT)
            self.tloss = None
            for i, batch in pbar:
                self.trigger_callbacks("on_batch_start")
                # forward
                batch = self.preprocess_batch(batch)

                # warmup
                ni = i + nb * epoch
                if ni <= nw:
                    xi = [0, nw]  # x interp
                    accumulate = max(1, np.interp(ni, xi, [1, self.args.nbs / self.args.batch_size]).round())
                    for j, x in enumerate(self.optimizer.param_groups):
                        # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                        x['lr'] = np.interp(
                            ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x['initial_lr'] * self.lf(epoch)])
                        if 'momentum' in x:
                            x['momentum'] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

                preds = self.model(batch["img"])
                self.loss, self.loss_items = self.criterion(preds, batch)
                self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \
                    else self.loss_items

                # backward (gradients accumulate across batches; optimizer_step() clears them)
                self.scaler.scale(self.loss).backward()

                # optimize
                if ni - last_opt_step >= accumulate:
                    self.optimizer_step()
                    last_opt_step = ni
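
                # Gradients from up to `accumulate` batches are summed before a
                # single optimizer step, so the effective batch size approaches
                # the nominal `args.nbs`; warmup ramps `accumulate` up from 1.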

                # log
                mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0)  # (GB)
                loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1
                losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
                if rank in {-1, 0}:
                    pbar.set_description(
                        (" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem,
                                                                              *losses, batch["img"].shape[-1]))
                self.trigger_callbacks('on_batch_end')

            self.scheduler.step()  # advance the LR schedule once per epoch

            if rank in {-1, 0}:
                # validation
                self.trigger_callbacks('on_val_start')
                self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights'])
                self.metrics, self.fitness = self.validate()
                self.trigger_callbacks('on_val_end')

                # save model
                if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs):
                    self.save_model()
                    self.trigger_callbacks('on_model_save')

            self.epoch += 1
            tnow = time.time()
            self.epoch_time = tnow - self.epoch_time_start
            self.epoch_time_start = tnow

            # TODO: termination condition

        self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)")
        self.trigger_callbacks('on_train_end')
        if world_size > 1:
            dist.destroy_process_group()

    def save_model(self):
        ckpt = {
            'epoch': self.epoch,
            'best_fitness': self.best_fitness,
            'model': deepcopy(de_parallel(self.model)).half(),
            'ema': deepcopy(self.ema.ema).half(),
            'updates': self.ema.updates,
            'optimizer': self.optimizer.state_dict(),
            'train_args': self.args,
            'date': datetime.now().isoformat()}

        # Save last, best and delete
        torch.save(ckpt, self.last)
        if self.best_fitness == self.fitness:
            torch.save(ckpt, self.best)
        del ckpt

    def get_dataset(self, data):
        """
        Get train and val paths from the data dict if it exists. Returns None if the data format is not recognized.
        """
        return data["train"], data["val"]

    def get_model(self, model: Union[str, Path]):
        """
        Load/create/download a model for any task.
        """
        pretrained = not str(model).endswith(".yaml")
        return self.load_model(model_cfg=None if pretrained else model,
                               weights=get_model(model) if pretrained else None,
                               data=self.data)  # model

    def optimizer_step(self):
        self.scaler.unscale_(self.optimizer)  # unscale gradients
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        if self.ema:
            self.ema.update(self.model)

    def preprocess_batch(self, batch):
        """
        Allows custom preprocessing of model inputs and ground truths depending on task type.
        """
        return batch

    def validate(self):
        """
        Runs validation on the test set using self.validator.
        # TODO: discuss validator class. Enforce that a validator metrics dict should contain
        a "fitness" metric.
        """
        metrics = self.validator(self)
        fitness = metrics.get("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
        if not self.best_fitness or self.best_fitness < fitness:
            self.best_fitness = fitness
        return metrics, fitness

    def log(self, text, rank=-1):
        """
        Logs the given text to the given rank's process if provided, otherwise logs to all ranks.
        :param text: text to log
        :param rank: List[Int]
        """
        if rank in {-1, 0}:
            self.console.info(text)

    def load_model(self, model_cfg, weights, data):
        raise NotImplementedError("This task trainer doesn't support loading cfg files")

    def get_validator(self):
        raise NotImplementedError("get_validator function not implemented in trainer")

    def get_dataloader(self, dataset_path, batch_size=16, rank=0):
        """
        Returns a dataloader derived from torch.data.Dataloader.
        """
        raise NotImplementedError("get_dataloader function not implemented in trainer")

    def criterion(self, preds, batch):
        """
        Returns the loss and individual loss items as Tensors.
        """
        raise NotImplementedError("criterion function not implemented in trainer")

    def label_loss_items(self, loss_items):
        """
        Returns a loss dict with labelled training loss item tensors.
        """
        # Not needed for classification but necessary for segmentation & detection
        return {"loss": loss_items}

    def set_model_attributes(self):
        """
        To set or update model parameters before training.
        """
        pass

    def build_targets(self, preds, targets):
        pass


def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    # TODO: 1. docstring with example? 2. Move this inside Trainer? or utils?
    # YOLOv5 3-param group optimizer: 0) weights with decay, 1) weights no decay, 2) biases no decay
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):  # bias (no decay)
            g[2].append(v.bias)
        if isinstance(v, bn):  # weight (no decay)
            g[1].append(v.weight)
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):  # weight (with decay)
            g[0].append(v.weight)

    if name == 'Adam':
        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))  # adjust beta1 to momentum
    elif name == 'AdamW':
        optimizer = torch.optim.AdamW(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == 'RMSProp':
        optimizer = torch.optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == 'SGD':
        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented.')

    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # add g0 with weight_decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(f"optimizer: {type(optimizer).__name__}(lr={lr}) with parameter groups "
                f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias")
    return optimizer
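
# Example (sketch): for a small torch model the three parameter groups split as
# expected; with the defaults above this logs roughly
# "optimizer: SGD(lr=0.01) with parameter groups 1 weight(decay=0.0), 2 weight(decay=1e-05), 3 bias".
#
#   m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 2))
#   opt = build_optimizer(m, name='SGD', lr=0.01)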


# Dummy validator
def val(trainer: BaseTrainer):
    trainer.console.info("validating")
    return {"metric_1": 0.1, "metric_2": 0.2, "fitness": 1}