MPS unified memory cache empty (#16078)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Authored by Quet Almahdi Morris, committed by GitHub
parent ccd2937aa1
commit 4d5afa7e0d
      ultralytics/engine/trainer.py

@@ -28,6 +28,7 @@ from ultralytics.utils import (
     DEFAULT_CFG,
     LOCAL_RANK,
     LOGGER,
+    MACOS,
     RANK,
     TQDM,
     __version__,
@@ -453,7 +454,10 @@ class BaseTrainer:
                 self.stop |= epoch >= self.epochs  # stop if exceeded epochs
             self.run_callbacks("on_fit_epoch_end")
             gc.collect()
-            torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors
+            if MACOS:
+                torch.mps.empty_cache()  # clear unified memory at end of epoch, may help MPS' management of 'unlimited' virtual memory
+            else:
+                torch.cuda.empty_cache()  # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors

             # Early Stopping
             if RANK != -1:  # if DDP training
@@ -475,7 +479,11 @@ class BaseTrainer:
                 self.plot_metrics()
             self.run_callbacks("on_train_end")
         gc.collect()
-        torch.cuda.empty_cache()
+        if MACOS:
+            torch.mps.empty_cache()
+        else:
+            torch.cuda.empty_cache()
         self.run_callbacks("teardown")

     def read_results_csv(self):
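For context, below is a minimal standalone sketch of the device-aware cache-clearing pattern this commit applies inside BaseTrainer. The helper name clear_device_cache and the re-derivation of MACOS via platform.system() are assumptions for the sake of a self-contained example, not part of the Ultralytics API; torch.mps.empty_cache() and torch.cuda.empty_cache() are real PyTorch calls.

import gc
import platform

import torch

MACOS = platform.system() == "Darwin"  # stand-in for ultralytics.utils.MACOS


def clear_device_cache():
    """Release cached accelerator memory, e.g. at the end of each training epoch."""
    gc.collect()  # drop unreachable Python objects first so their tensors become freeable
    if MACOS and torch.backends.mps.is_available():
        torch.mps.empty_cache()  # return cached unified memory held by the MPS allocator
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached, unoccupied CUDA memory to the driver

Calling such a helper at epoch boundaries mirrors the trainer change: on Apple Silicon the MPS backend draws from unified memory shared with the CPU, so emptying its cache can relieve memory pressure much as torch.cuda.empty_cache() does for CUDA devices.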
