|
|
|
@ -56,8 +56,6 @@ from ultralytics.utils.torch_utils import ( |
|
|
|
|
|
|
|
|
|
class BaseTrainer: |
|
|
|
|
""" |
|
|
|
|
BaseTrainer. |
|
|
|
|
|
|
|
|
|
A base class for creating trainers. |
|
|
|
|
|
|
|
|
|
Attributes: |
|
|
|
@ -478,12 +476,16 @@ class BaseTrainer: |
|
|
|
|
torch.cuda.empty_cache() |
|
|
|
|
self.run_callbacks("teardown") |
|
|
|
|
|
|
|
|
|
def read_results_csv(self): |
|
|
|
|
"""Read results.csv into a dict using pandas.""" |
|
|
|
|
import pandas as pd # scope for faster 'import ultralytics' |
|
|
|
|
|
|
|
|
|
return {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()} |
|
|
|
|
|
|
|
|
|
def save_model(self): |
|
|
|
|
"""Save model training checkpoints with additional metadata.""" |
|
|
|
|
import io |
|
|
|
|
|
|
|
|
|
import pandas as pd # scope for faster 'import ultralytics' |
|
|
|
|
|
|
|
|
|
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls) |
|
|
|
|
buffer = io.BytesIO() |
|
|
|
|
torch.save( |
|
|
|
@ -496,7 +498,7 @@ class BaseTrainer: |
|
|
|
|
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())), |
|
|
|
|
"train_args": vars(self.args), # save as dict |
|
|
|
|
"train_metrics": {**self.metrics, **{"fitness": self.fitness}}, |
|
|
|
|
"train_results": {k.strip(): v for k, v in pd.read_csv(self.csv).to_dict(orient="list").items()}, |
|
|
|
|
"train_results": self.read_results_csv(), |
|
|
|
|
"date": datetime.now().isoformat(), |
|
|
|
|
"version": __version__, |
|
|
|
|
"license": "AGPL-3.0 (https://ultralytics.com/license)", |
|
|
|
@ -646,6 +648,9 @@ class BaseTrainer: |
|
|
|
|
if f.exists(): |
|
|
|
|
strip_optimizer(f) # strip optimizers |
|
|
|
|
if f is self.best: |
|
|
|
|
if self.last.is_file(): # update best.pt train_metrics from last.pt |
|
|
|
|
k = "train_results" |
|
|
|
|
torch.save({**torch.load(self.best), **{k: torch.load(self.last)[k]}}, self.best) |
|
|
|
|
LOGGER.info(f"\nValidating {f}...") |
|
|
|
|
self.validator.args.plots = self.args.plots |
|
|
|
|
self.metrics = self.validator(model=f) |
|
|
|
|