|
|
|
@ -1,10 +1,6 @@ |
|
|
|
|
""" |
|
|
|
|
Simple training loop; Boilerplate that could apply to any arbitrary neural network, |
|
|
|
|
""" |
|
|
|
|
# TODOs |
|
|
|
|
# 1. finish _set_model_attributes |
|
|
|
|
# 2. allow num_class update for both pretrained and csv_loaded models |
|
|
|
|
# 3. save |
|
|
|
|
|
|
|
|
|
import os |
|
|
|
|
import time |
|
|
|
@ -24,7 +20,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
import ultralytics.yolo.utils as utils |
|
|
|
|
import ultralytics.yolo.utils.loggers as loggers |
|
|
|
|
import ultralytics.yolo.utils.callbacks as callbacks |
|
|
|
|
from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml |
|
|
|
|
from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT |
|
|
|
|
from ultralytics.yolo.utils.checks import print_args |
|
|
|
@ -73,8 +69,9 @@ class BaseTrainer: |
|
|
|
|
self.fitness = None |
|
|
|
|
self.loss = None |
|
|
|
|
|
|
|
|
|
for callback, func in loggers.default_callbacks.items(): |
|
|
|
|
for callback, func in callbacks.default_callbacks.items(): |
|
|
|
|
self.add_callback(callback, func) |
|
|
|
|
callbacks.add_integration_callbacks(self) |
|
|
|
|
|
|
|
|
|
def _get_config(self, config: Union[str, DictConfig], overrides: Union[str, Dict] = {}): |
|
|
|
|
""" |
|
|
|
@ -146,7 +143,6 @@ class BaseTrainer: |
|
|
|
|
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=-1) |
|
|
|
|
self.validator = self.get_validator() |
|
|
|
|
print("created testloader :", rank) |
|
|
|
|
self.console.info(self.progress_string()) |
|
|
|
|
self.ema = ModelEMA(self.model) |
|
|
|
|
|
|
|
|
|
def _do_train(self, rank=-1, world_size=1): |
|
|
|
@ -155,7 +151,7 @@ class BaseTrainer: |
|
|
|
|
else: |
|
|
|
|
self.model = self.model.to(self.device) |
|
|
|
|
|
|
|
|
|
# callback hook. before_train |
|
|
|
|
self.trigger_callbacks("before_train") |
|
|
|
|
self._setup_train(rank) |
|
|
|
|
|
|
|
|
|
self.epoch = 1 |
|
|
|
@ -163,22 +159,22 @@ class BaseTrainer: |
|
|
|
|
self.epoch_time_start = time.time() |
|
|
|
|
self.train_time_start = time.time() |
|
|
|
|
for epoch in range(self.args.epochs): |
|
|
|
|
# callback hook. on_epoch_start |
|
|
|
|
self.trigger_callbacks("on_epoch_start") |
|
|
|
|
self.model.train() |
|
|
|
|
pbar = enumerate(self.train_loader) |
|
|
|
|
if rank in {-1, 0}: |
|
|
|
|
pbar = tqdm(enumerate(self.train_loader), total=len(self.train_loader), bar_format=TQDM_BAR_FORMAT) |
|
|
|
|
tloss = None |
|
|
|
|
self.tloss = None |
|
|
|
|
for i, batch in pbar: |
|
|
|
|
# img, label (classification)/ img, targets, paths, _, masks(detection) |
|
|
|
|
# callback hook. on_batch_start |
|
|
|
|
self.trigger_callbacks("on_batch_start") |
|
|
|
|
# forward |
|
|
|
|
batch = self.preprocess_batch(batch) |
|
|
|
|
|
|
|
|
|
# TODO: warmup, multiscale |
|
|
|
|
preds = self.model(batch["img"]) |
|
|
|
|
self.loss, self.loss_items = self.criterion(preds, batch) |
|
|
|
|
tloss = (tloss * i + self.loss_items) / (i + 1) if tloss is not None else self.loss_items |
|
|
|
|
self.tloss = (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None \ |
|
|
|
|
else self.loss_items |
|
|
|
|
|
|
|
|
|
# backward |
|
|
|
|
self.model.zero_grad(set_to_none=True) |
|
|
|
@ -186,28 +182,28 @@ class BaseTrainer: |
|
|
|
|
|
|
|
|
|
# optimize |
|
|
|
|
self.optimizer_step() |
|
|
|
|
self.trigger_callbacks('on_batch_end') |
|
|
|
|
|
|
|
|
|
# log |
|
|
|
|
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB) |
|
|
|
|
loss_len = tloss.shape[0] if len(tloss.size()) else 1 |
|
|
|
|
losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0) |
|
|
|
|
loss_len = self.tloss.shape[0] if len(self.tloss.size()) else 1 |
|
|
|
|
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0) |
|
|
|
|
if rank in {-1, 0}: |
|
|
|
|
pbar.set_description( |
|
|
|
|
(" {} " + "{:.3f} " * (1 + loss_len) + ' {} ').format(f'{epoch + 1}/{self.args.epochs}', mem, |
|
|
|
|
*losses, batch["img"].shape[-1])) |
|
|
|
|
self.trigger_callbacks('on_batch_end') |
|
|
|
|
|
|
|
|
|
if rank in [-1, 0]: |
|
|
|
|
# validation |
|
|
|
|
# callback: on_val_start() |
|
|
|
|
self.trigger_callbacks('on_val_start') |
|
|
|
|
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) |
|
|
|
|
self.validate() |
|
|
|
|
# callback: on_val_end() |
|
|
|
|
self.metrics, self.fitness = self.validate() |
|
|
|
|
self.trigger_callbacks('on_val_end') |
|
|
|
|
|
|
|
|
|
# save model |
|
|
|
|
if (not self.args.nosave) or (self.epoch + 1 == self.args.epochs): |
|
|
|
|
self.save_model() |
|
|
|
|
# callback; on_model_save |
|
|
|
|
self.trigger_callbacks('on_model_save') |
|
|
|
|
|
|
|
|
|
self.epoch += 1 |
|
|
|
|
tnow = time.time() |
|
|
|
@ -216,9 +212,8 @@ class BaseTrainer: |
|
|
|
|
|
|
|
|
|
# TODO: termination condition |
|
|
|
|
|
|
|
|
|
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours) \ |
|
|
|
|
\n{self.usage_help()}") |
|
|
|
|
# callback; on_train_end |
|
|
|
|
self.log(f"\nTraining complete ({(time.time() - self.train_time_start) / 3600:.3f} hours)") |
|
|
|
|
self.trigger_callbacks('on_train_end') |
|
|
|
|
dist.destroy_process_group() if world_size != 1 else None |
|
|
|
|
|
|
|
|
|
def save_model(self): |
|
|
|
@ -238,12 +233,6 @@ class BaseTrainer: |
|
|
|
|
torch.save(ckpt, self.best) |
|
|
|
|
del ckpt |
|
|
|
|
|
|
|
|
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0): |
|
|
|
|
""" |
|
|
|
|
Returns dataloader derived from torch.data.Dataloader |
|
|
|
|
""" |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
def get_dataset(self, data): |
|
|
|
|
""" |
|
|
|
|
Get train, val path from data dict if it exists. Returns None if data format is not recognized |
|
|
|
@ -259,12 +248,6 @@ class BaseTrainer: |
|
|
|
|
weights=get_model(model) if pretrained else None, |
|
|
|
|
data=self.data) # model |
|
|
|
|
|
|
|
|
|
def load_model(self, model_cfg, weights, data): |
|
|
|
|
raise NotImplementedError("This task trainer doesn't support loading cfg files") |
|
|
|
|
|
|
|
|
|
def get_validator(self): |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
def optimizer_step(self): |
|
|
|
|
self.scaler.unscale_(self.optimizer) # unscale gradients |
|
|
|
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients |
|
|
|
@ -286,48 +269,55 @@ class BaseTrainer: |
|
|
|
|
# TODO: discuss validator class. Enforce that a validator metrics dict should contain |
|
|
|
|
"fitness" metric. |
|
|
|
|
""" |
|
|
|
|
self.metrics = self.validator(self) |
|
|
|
|
self.fitness = self.metrics.get("fitness", |
|
|
|
|
-self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found |
|
|
|
|
if not self.best_fitness or self.best_fitness < self.fitness: |
|
|
|
|
metrics = self.validator(self) |
|
|
|
|
fitness = metrics.get("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found |
|
|
|
|
if not self.best_fitness or self.best_fitness < fitness: |
|
|
|
|
self.best_fitness = self.fitness |
|
|
|
|
return metrics, fitness |
|
|
|
|
|
|
|
|
|
def set_model_attributes(self): |
|
|
|
|
def log(self, text, rank=-1): |
|
|
|
|
""" |
|
|
|
|
To set or update model parameters before training. |
|
|
|
|
Logs the given text to given ranks process if provided, otherwise logs to all ranks |
|
|
|
|
:param text: text to log |
|
|
|
|
:param rank: List[Int] |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
pass |
|
|
|
|
if rank in {-1, 0}: |
|
|
|
|
self.console.info(text) |
|
|
|
|
|
|
|
|
|
def build_targets(self, preds, targets): |
|
|
|
|
pass |
|
|
|
|
def load_model(self, model_cfg, weights, data): |
|
|
|
|
raise NotImplementedError("This task trainer doesn't support loading cfg files") |
|
|
|
|
|
|
|
|
|
def get_validator(self): |
|
|
|
|
raise NotImplementedError("get_validator function not implemented in trainer") |
|
|
|
|
|
|
|
|
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0): |
|
|
|
|
""" |
|
|
|
|
Returns dataloader derived from torch.data.Dataloader |
|
|
|
|
""" |
|
|
|
|
raise NotImplementedError("get_dataloader function not implemented in trainer") |
|
|
|
|
|
|
|
|
|
def criterion(self, preds, batch): |
|
|
|
|
""" |
|
|
|
|
Returns loss and individual loss items as Tensor |
|
|
|
|
""" |
|
|
|
|
pass |
|
|
|
|
raise NotImplementedError("criterion function not implemented in trainer") |
|
|
|
|
|
|
|
|
|
def progress_string(self): |
|
|
|
|
def label_loss_items(self, loss_items): |
|
|
|
|
""" |
|
|
|
|
Returns progress string depending on task type. |
|
|
|
|
Returns a loss dict with labelled training loss items tensor |
|
|
|
|
""" |
|
|
|
|
return '' |
|
|
|
|
# Not needed for classification but necessary for segmentation & detection |
|
|
|
|
return {"loss": loss_items} |
|
|
|
|
|
|
|
|
|
def usage_help(self): |
|
|
|
|
def set_model_attributes(self): |
|
|
|
|
""" |
|
|
|
|
Returns usage functionality. gets printed to the console after training. |
|
|
|
|
To set or update model parameters before training. |
|
|
|
|
""" |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
def log(self, text, rank=-1): |
|
|
|
|
""" |
|
|
|
|
Logs the given text to given ranks process if provided, otherwise logs to all ranks |
|
|
|
|
:param text: text to log |
|
|
|
|
:param rank: List[Int] |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
if rank in {-1, 0}: |
|
|
|
|
self.console.info(text) |
|
|
|
|
def build_targets(self, preds, targets): |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5): |
|
|
|
|