From 1e547e60a0b632f55b975f553fe9300340e920e0 Mon Sep 17 00:00:00 2001
From: Laughing <61612323+Laughing-q@users.noreply.github.com>
Date: Tue, 2 Apr 2024 17:55:11 +0800
Subject: [PATCH] Fix learning rate gap on resume (#9468)

Signed-off-by: Glenn Jocher
Co-authored-by: UltralyticsAssistant
Co-authored-by: Glenn Jocher
Co-authored-by: EunChan Kim
Co-authored-by: Lakshantha Dissanayake
Co-authored-by: RizwanMunawar
Co-authored-by: gs80140
---
 ultralytics/engine/trainer.py    | 19 ++++++++++---------
 ultralytics/utils/torch_utils.py |  5 +++--
 2 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py
index f92e815e9..c2391270d 100644
--- a/ultralytics/engine/trainer.py
+++ b/ultralytics/engine/trainer.py
@@ -331,6 +331,10 @@ class BaseTrainer:
         while True:
             self.epoch = epoch
             self.run_callbacks("on_train_epoch_start")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
+                self.scheduler.step()
+
             self.model.train()
             if RANK != -1:
                 self.train_loader.sampler.set_epoch(epoch)
@@ -426,15 +430,12 @@ class BaseTrainer:
             t = time.time()
             self.epoch_time = t - self.epoch_time_start
             self.epoch_time_start = t
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
-                if self.args.time:
-                    mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
-                    self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
-                    self._setup_scheduler()
-                    self.scheduler.last_epoch = self.epoch  # do not move
-                    self.stop |= epoch >= self.epochs  # stop if exceeded epochs
-                self.scheduler.step()
+            if self.args.time:
+                mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
+                self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
+                self._setup_scheduler()
+                self.scheduler.last_epoch = self.epoch  # do not move
+                self.stop |= epoch >= self.epochs  # stop if exceeded epochs
             self.run_callbacks("on_fit_epoch_end")
             torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
 
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 96154d60e..32eae2c80 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -16,7 +16,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchvision
 
-from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
+from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, __version__
 from ultralytics.utils.checks import PYTHON_VERSION, check_version
 
 try:
@@ -614,8 +614,9 @@ class EarlyStopping:
         self.possible_stop = delta >= (self.patience - 1)  # possible stop may occur next epoch
         stop = delta >= self.patience  # stop training if patience exceeded
         if stop:
+            prefix = colorstr("EarlyStopping: ")
             LOGGER.info(
-                f"Stopping training early as no improvement observed in last {self.patience} epochs. "
+                f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
                 f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
                 f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
                 f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
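
Note: the sketch below is a minimal illustration of the scheduling pattern this patch moves to, not the Ultralytics trainer itself; the model, optimizer, epoch counts, and lambda schedule are assumed for illustration only. Stepping the scheduler at the top of each epoch (with the scheduler's last_epoch restored from the checkpoint) means a resumed run immediately trains at the learning rate that matches its epoch, instead of spending one epoch on the stale checkpoint LR.

    import warnings

    import torch

    model = torch.nn.Linear(10, 1)  # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    epochs, start_epoch = 100, 50  # e.g. resuming from a checkpoint saved at epoch 50
    lf = lambda x: (1 - x / epochs) * 0.99 + 0.01  # assumed linear LR decay
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    scheduler.last_epoch = start_epoch - 1  # restore scheduler state on resume

    for epoch in range(start_epoch, epochs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'lr_scheduler.step() before optimizer.step()'
            scheduler.step()  # LR now matches this epoch, so there is no gap right after resume
        # ... per-batch forward/backward and optimizer.step() would go here ...

With the pre-patch ordering (step at epoch end), the first resumed epoch ran before any scheduler step, which is what produced the learning-rate gap the patch title refers to.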