From c1882a4327689c726280a710f1274749d1f8e8f3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 24 Aug 2024 23:17:09 +0800 Subject: [PATCH] 'best.pt' inherit all-epochs results curves from 'last.pt' (#15791) Signed-off-by: UltralyticsAssistant Co-authored-by: UltralyticsAssistant --- ultralytics/engine/trainer.py | 15 ++++++++++----- ultralytics/hub/session.py | 2 +- ultralytics/utils/__init__.py | 2 +- ultralytics/utils/checks.py | 4 ++-- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 6ebe75366..3d2515acb 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -56,8 +56,6 @@ from ultralytics.utils.torch_utils import ( class BaseTrainer: """ - BaseTrainer. - A base class for creating trainers. Attributes: @@ -478,12 +476,16 @@ class BaseTrainer: torch.cuda.empty_cache() self.run_callbacks("teardown") + def read_results_csv(self): + """Read results.csv into a dict using pandas.""" + import pandas as pd # scope for faster 'import ultralytics' + + return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()} + def save_model(self): """Save model training checkpoints with additional metadata.""" import io - import pandas as pd # scope for faster 'import ultralytics' - # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls) buffer = io.BytesIO() torch.save( @@ -496,7 +498,7 @@ class BaseTrainer: "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())), "train_args": vars(self.args), # save as dict "train_metrics": {**self.metrics, **{"fitness": self.fitness}}, - "train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}, + "train_results": self.read_results_csv(), "date": datetime.now().isoformat(), "version": __version__, "license": "AGPL-3.0 (https://ultralytics.com/license)", @@ -646,6 +648,9 @@ class BaseTrainer: if f.exists(): 
strip_optimizer(f) # strip optimizers if f is self.best: + if self.last.is_file(): # update best.pt train_results from last.pt + k = "train_results" + torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best) LOGGER.info(f"\nValidating {f}...") self.validator.args.plots = self.args.plots self.metrics = self.validator(model=f) diff --git a/ultralytics/hub/session.py b/ultralytics/hub/session.py index 500c23591..75608ddad 100644 --- a/ultralytics/hub/session.py +++ b/ultralytics/hub/session.py @@ -276,7 +276,7 @@ class HUBTrainingSession: # if request related to metrics upload and exceed retries if response is None and kwargs.get("metrics"): - self.metrics_upload_failed_queue.update(kwargs.get("metrics", None)) + self.metrics_upload_failed_queue.update(kwargs.get("metrics")) return response diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py index 9a2209725..f076b5b39 100644 --- a/ultralytics/utils/__init__.py +++ b/ultralytics/utils/__init__.py @@ -713,7 +713,7 @@ def colorstr(*input): In the second form, 'blue' and 'bold' will be applied by default. Args: - *input (str): A sequence of strings where the first n-1 strings are color and style arguments, + *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments, and the last string is the one to be colored. 
Supported Colors and Styles: diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py index b9bcef3f6..ca0555498 100644 --- a/ultralytics/utils/checks.py +++ b/ultralytics/utils/checks.py @@ -23,6 +23,7 @@ from ultralytics.utils import ( ASSETS, AUTOINSTALL, IS_COLAB, + IS_GIT_DIR, IS_JUPYTER, IS_KAGGLE, IS_PIP_PACKAGE, @@ -582,10 +583,9 @@ def check_yolo(verbose=True, device=""): def collect_system_info(): """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.""" - import psutil - from ultralytics.utils import ENVIRONMENT, IS_GIT_DIR + from ultralytics.utils import ENVIRONMENT # scope to avoid circular import from ultralytics.utils.torch_utils import get_cpu_info ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB