'best.pt' inherit all-epochs results curves from 'last.pt' (#15791)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
action-recog
Glenn Jocher 3 months ago committed by fcakyon
parent d4939aa58a
commit 679feabb9c
  1. 15
      ultralytics/engine/trainer.py
  2. 2
      ultralytics/hub/session.py
  3. 2
      ultralytics/utils/__init__.py
  4. 4
      ultralytics/utils/checks.py

@ -56,8 +56,6 @@ from ultralytics.utils.torch_utils import (
class BaseTrainer: class BaseTrainer:
""" """
BaseTrainer.
A base class for creating trainers. A base class for creating trainers.
Attributes: Attributes:
@ -478,12 +476,16 @@ class BaseTrainer:
torch.cuda.empty_cache() torch.cuda.empty_cache()
self.run_callbacks("teardown") self.run_callbacks("teardown")
def read_results_csv(self):
"""Read results.csv into a dict using pandas."""
import pandas as pd # scope for faster 'import ultralytics'
return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
def save_model(self): def save_model(self):
"""Save model training checkpoints with additional metadata.""" """Save model training checkpoints with additional metadata."""
import io import io
import pandas as pd # scope for faster 'import ultralytics'
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls) # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
buffer = io.BytesIO() buffer = io.BytesIO()
torch.save( torch.save(
@ -496,7 +498,7 @@ class BaseTrainer:
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())), "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
"train_args": vars(self.args), # save as dict "train_args": vars(self.args), # save as dict
"train_metrics": {**self.metrics, **{"fitness": self.fitness}}, "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}, "train_results": self.read_results_csv(),
"date": datetime.now().isoformat(), "date": datetime.now().isoformat(),
"version": __version__, "version": __version__,
"license": "AGPL-3.0 (https://ultralytics.com/license)", "license": "AGPL-3.0 (https://ultralytics.com/license)",
@ -646,6 +648,9 @@ class BaseTrainer:
if f.exists(): if f.exists():
strip_optimizer(f) # strip optimizers strip_optimizer(f) # strip optimizers
if f is self.best: if f is self.best:
if self.last.is_file(): # update best.pt train_metrics from last.pt
k = "train_results"
torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best)
LOGGER.info(f"\nValidating {f}...") LOGGER.info(f"\nValidating {f}...")
self.validator.args.plots = self.args.plots self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f) self.metrics = self.validator(model=f)

@ -276,7 +276,7 @@ class HUBTrainingSession:
# if request related to metrics upload and exceed retries # if request related to metrics upload and exceed retries
if response is None and kwargs.get("metrics"): if response is None and kwargs.get("metrics"):
self.metrics_upload_failed_queue.update(kwargs.get("metrics", None)) self.metrics_upload_failed_queue.update(kwargs.get("metrics"))
return response return response

@ -716,7 +716,7 @@ def colorstr(*input):
In the second form, 'blue' and 'bold' will be applied by default. In the second form, 'blue' and 'bold' will be applied by default.
Args: Args:
*input (str): A sequence of strings where the first n-1 strings are color and style arguments, *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,
and the last string is the one to be colored. and the last string is the one to be colored.
Supported Colors and Styles: Supported Colors and Styles:

@ -23,6 +23,7 @@ from ultralytics.utils import (
ASSETS, ASSETS,
AUTOINSTALL, AUTOINSTALL,
IS_COLAB, IS_COLAB,
IS_GIT_DIR,
IS_JUPYTER, IS_JUPYTER,
IS_KAGGLE, IS_KAGGLE,
IS_PIP_PACKAGE, IS_PIP_PACKAGE,
@ -582,10 +583,9 @@ def check_yolo(verbose=True, device=""):
def collect_system_info(): def collect_system_info():
"""Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.""" """Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA."""
import psutil import psutil
from ultralytics.utils import ENVIRONMENT, IS_GIT_DIR from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
from ultralytics.utils.torch_utils import get_cpu_info from ultralytics.utils.torch_utils import get_cpu_info
ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB ram_info = psutil.virtual_memory().total / (1024**3) # Convert bytes to GB

Loading…
Cancel
Save