diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index 03965a7296..1b10468105 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -454,7 +454,7 @@ class BaseTrainer: self.stop |= epoch >= self.epochs # stop if exceeded epochs self.run_callbacks("on_fit_epoch_end") gc.collect() - if MACOS: + if MACOS and self.device.type == "mps": torch.mps.empty_cache() # clear unified memory at end of epoch, may help MPS' management of 'unlimited' virtual memoy else: torch.cuda.empty_cache() # clear GPU memory at end of epoch, may help reduce CUDA out of memory errors @@ -479,7 +479,7 @@ class BaseTrainer: self.plot_metrics() self.run_callbacks("on_train_end") gc.collect() - if MACOS: + if MACOS and self.device.type == "mps": torch.mps.empty_cache() else: torch.cuda.empty_cache()