diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 2d5fc62461..03965a7296 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -28,6 +28,7 @@ from ultralytics.utils import ( DEFAULT_CFG, LOCAL_RANK, LOGGER, + MACOS, RANK, TQDM, __version__, @@ -453,7 +454,10 @@ class BaseTrainer: self.stop |= epoch >= self.epochs # stop if exceeded epochs self.run_callbacks("on_fit_epoch_end") gc.collect() - torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors + if MACOS: + torch.mps.empty_cache() # clear unified memory at end of epoch, may help MPS' management of 'unlimited' virtual memoy + else: + torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors # Early Stopping if RANK != -1: # if DDP training @@ -475,7 +479,11 @@ class BaseTrainer: self.plot_metrics() self.run_callbacks("on_train_end") gc.collect() - torch.cuda.empty_cache() + if MACOS: + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() + self.run_callbacks("teardown") def read_results_csv(self):