Fix `torch.amp` has no attribute `GradScaler` (#14647)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
pull/14675/head
Laughing authored 4 months ago · committed by GitHub
parent b7c90526c8
commit 1c351b5036
1 changed file: ultralytics/engine/trainer.py

@@ -41,7 +41,6 @@ from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_m
 from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
 from ultralytics.utils.files import get_latest_run
 from ultralytics.utils.torch_utils import (
-    TORCH_1_13,
     EarlyStopping,
     ModelEMA,
     autocast,
@@ -266,11 +265,7 @@ class BaseTrainer:
         if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
-        self.scaler = (
-            torch.amp.GradScaler("cuda", enabled=self.amp)
-            if TORCH_1_13
-            else torch.cuda.amp.GradScaler(enabled=self.amp)
-        )
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.amp)
         if world_size > 1:
             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
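For context: the removed `TORCH_1_13` guard routed PyTorch >= 1.13 users to `torch.amp.GradScaler`, but that device-agnostic constructor only exists in newer PyTorch releases (roughly 2.3+), so older installs raised `AttributeError: module 'torch.amp' has no attribute 'GradScaler'`; this commit falls back to the long-standing `torch.cuda.amp.GradScaler` API. A minimal sketch of a version-tolerant alternative, using a hypothetical helper `make_grad_scaler` that is not part of Ultralytics and a feature check instead of a version check:

```python
import torch


def make_grad_scaler(enabled: bool = True):
    """Build a CUDA GradScaler on both old and new PyTorch builds.

    Hypothetical helper for illustration (not part of Ultralytics): the
    device-agnostic torch.amp.GradScaler only appears in newer PyTorch
    releases (roughly 2.3+), so fall back to torch.cuda.amp.GradScaler
    when the attribute is missing.
    """
    if hasattr(getattr(torch, "amp", None), "GradScaler"):  # newer PyTorch
        return torch.amp.GradScaler("cuda", enabled=enabled)
    return torch.cuda.amp.GradScaler(enabled=enabled)  # older PyTorch fallback


# usage, e.g. inside a trainer's setup step:
# self.scaler = make_grad_scaler(enabled=self.amp)
```

Checking `hasattr` rather than a version constant avoids the original failure mode: the scaler that gets constructed is always one the running PyTorch build actually provides.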
