From 7c7f456710d8f4c538453fda96fb0ea1d4e84d7a Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Tue, 3 Sep 2024 18:03:01 +0800 Subject: [PATCH] Fix `torch.cuda.amp.GradScaler` warning (#15978) Co-authored-by: UltralyticsAssistant --- ultralytics/engine/trainer.py | 5 ++++- ultralytics/utils/torch_utils.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ultralytics/engine/trainer.py b/ultralytics/engine/trainer.py index c5b8a13f30..2d5fc62461 100644 --- a/ultralytics/engine/trainer.py +++ b/ultralytics/engine/trainer.py @@ -42,6 +42,7 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command from ultralytics.utils.files import get_latest_run from ultralytics.utils.torch_utils import ( + TORCH_2_4, EarlyStopping, ModelEMA, autocast, @@ -265,7 +266,9 @@ class BaseTrainer: if RANK > -1 and world_size > 1: # DDP dist.broadcast(self.amp, src=0) # broadcast the tensor from rank 0 to all other ranks (returns None) self.amp = bool(self.amp) # as boolean - self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp) + self.scaler = ( + torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp) + ) if world_size > 1: self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index c2338e184b..16bcddadd0 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -40,6 +40,7 @@ except ImportError: TORCH_1_9 = check_version(torch.__version__, "1.9.0") TORCH_1_13 = check_version(torch.__version__, "1.13.0") TORCH_2_0 = check_version(torch.__version__, "2.0.0") +TORCH_2_4 = check_version(torch.__version__, "2.4.0") TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0") TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0") TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")