|
|
|
@ -9,6 +9,7 @@ Simple training loop; Boilerplate that could apply to any arbitrary neural netwo |
|
|
|
|
import os |
|
|
|
|
import time |
|
|
|
|
from collections import defaultdict |
|
|
|
|
from copy import deepcopy |
|
|
|
|
from datetime import datetime |
|
|
|
|
from pathlib import Path |
|
|
|
|
from typing import Dict, Union |
|
|
|
@ -29,6 +30,7 @@ from ultralytics.yolo.utils import LOGGER, ROOT, TQDM_BAR_FORMAT |
|
|
|
|
from ultralytics.yolo.utils.checks import print_args |
|
|
|
|
from ultralytics.yolo.utils.files import increment_path, save_yaml |
|
|
|
|
from ultralytics.yolo.utils.modeling import get_model |
|
|
|
|
from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel |
|
|
|
|
|
|
|
|
|
DEFAULT_CONFIG = ROOT / "yolo/utils/configs/default.yaml" |
|
|
|
|
|
|
|
|
@ -63,6 +65,7 @@ class BaseTrainer: |
|
|
|
|
self.trainset, self.testset = self.get_dataset(self.data) |
|
|
|
|
if self.args.model: |
|
|
|
|
self.model = self.get_model(self.args.model) |
|
|
|
|
self.ema = None |
|
|
|
|
|
|
|
|
|
# epoch level metrics |
|
|
|
|
self.metrics = {} # handle metrics returned by validator |
|
|
|
@ -144,6 +147,7 @@ class BaseTrainer: |
|
|
|
|
self.validator = self.get_validator() |
|
|
|
|
print("created testloader :", rank) |
|
|
|
|
self.console.info(self.progress_string()) |
|
|
|
|
self.ema = ModelEMA(self.model) |
|
|
|
|
|
|
|
|
|
def _do_train(self, rank=-1, world_size=1): |
|
|
|
|
if world_size > 1: |
|
|
|
@ -196,6 +200,7 @@ class BaseTrainer: |
|
|
|
|
if rank in [-1, 0]: |
|
|
|
|
# validation |
|
|
|
|
# callback: on_val_start() |
|
|
|
|
self.ema.update_attr(self.model, include=['yaml', 'nc', 'args', 'names', 'stride', 'class_weights']) |
|
|
|
|
self.validate() |
|
|
|
|
# callback: on_val_end() |
|
|
|
|
|
|
|
|
@ -220,10 +225,10 @@ class BaseTrainer: |
|
|
|
|
ckpt = { |
|
|
|
|
'epoch': self.epoch, |
|
|
|
|
'best_fitness': self.best_fitness, |
|
|
|
|
'model': None, # deepcopy(ema.ema).half(), # deepcopy(de_parallel(model)).half(), |
|
|
|
|
'ema': None, # deepcopy(ema.ema).half(), |
|
|
|
|
'updates': None, # ema.updates, |
|
|
|
|
'optimizer': None, # optimizer.state_dict(), |
|
|
|
|
'model': deepcopy(de_parallel(self.model)).half(), |
|
|
|
|
'ema': deepcopy(self.ema.ema).half(), |
|
|
|
|
'updates': self.ema.updates, |
|
|
|
|
'optimizer': self.optimizer.state_dict(), |
|
|
|
|
'train_args': self.args, |
|
|
|
|
'date': datetime.now().isoformat()} |
|
|
|
|
|
|
|
|
@ -266,6 +271,8 @@ class BaseTrainer: |
|
|
|
|
self.scaler.step(self.optimizer) |
|
|
|
|
self.scaler.update() |
|
|
|
|
self.optimizer.zero_grad() |
|
|
|
|
if self.ema: |
|
|
|
|
self.ema.update(self.model) |
|
|
|
|
|
|
|
|
|
def preprocess_batch(self, batch): |
|
|
|
|
""" |
|
|
|
|