Merge branch 'main' into quan

mct-2.1.1
Francesco Mattioli 5 months ago committed by GitHub
commit addd3221b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 2
      ultralytics/__init__.py
  2. 15
      ultralytics/engine/trainer.py

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.8"
__version__ = "8.3.9"
import os

@ -469,11 +469,9 @@ class BaseTrainer:
if RANK in {-1, 0}:
# Do final val with best.pt
epochs = epoch - self.start_epoch + 1 # total training epochs
seconds = time.time() - self.train_time_start # total training seconds
LOGGER.info(f"\n{epochs} epochs completed in {seconds / 3600:.3f} hours.")
seconds = time.time() - self.train_time_start
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
self.final_eval()
self.validator.metrics.training = {"epochs": epochs, "seconds": seconds} # add training speed
if self.args.plots:
self.plot_metrics()
self.run_callbacks("on_train_end")
@ -504,7 +502,7 @@ class BaseTrainer:
"""Read results.csv into a dict using pandas."""
import pandas as pd # scope for faster 'import ultralytics'
return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}
return pd.read_csv(self.csv).to_dict(orient="list")
def save_model(self):
"""Save model training checkpoints with additional metadata."""
@ -654,10 +652,11 @@ class BaseTrainer:
def save_metrics(self, metrics):
"""Saves training metrics to a CSV file."""
keys, vals = list(metrics.keys()), list(metrics.values())
n = len(metrics) + 1 # number of cols
s = "" if self.csv.exists() else (("%23s," * n % tuple(["epoch"] + keys)).rstrip(",") + "\n") # header
n = len(metrics) + 2 # number of cols
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
t = time.time() - self.train_time_start
with open(self.csv, "a") as f:
f.write(s + ("%23.5g," * n % tuple([self.epoch + 1] + vals)).rstrip(",") + "\n")
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
def plot_metrics(self):
"""Plot and display metrics visually."""

Loading…
Cancel
Save