From 4d5afa7e0dcfd49072b28fc8715a7c3d49beabaf Mon Sep 17 00:00:00 2001 From: Quet Almahdi Morris Date: Sat, 7 Sep 2024 12:38:03 -0500 Subject: [PATCH] MPS unified memory cache empty (#16078) Co-authored-by: UltralyticsAssistant Co-authored-by: Glenn Jocher --- ultralytics/engine/trainer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 2d5fc62461..03965a7296 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -28,6 +28,7 @@ from ultralytics.utils import ( DEFAULT_CFG, LOCAL_RANK, LOGGER, + MACOS, RANK, TQDM, __version__, @@ -453,7 +454,10 @@ class BaseTrainer: self.stop |= epoch >= self.epochs # stop if exceeded epochs self.run_callbacks("on_fit_epoch_end") gc.collect() - torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors + if MACOS: + torch.mps.empty_cache() # clear unified memory at end of epoch, may help MPS's management of 'unlimited' virtual memory + else: + torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors # Early Stopping if RANK != -1: # if DDP training @@ -475,7 +479,11 @@ class BaseTrainer: self.plot_metrics() self.run_callbacks("on_train_end") gc.collect() - torch.cuda.empty_cache() + if MACOS: + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() + self.run_callbacks("teardown") def read_results_csv(self):