@@ -217,7 +217,7 @@ class BaseTrainer:
             callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
             self.amp = torch.tensor(check_amp(self.model), device=self.device)
             callbacks.default_callbacks = callbacks_backup  # restore callbacks
-        if RANK > -1:  # DDP
+        if RANK > -1 and world_size > 1:  # DDP
             dist.broadcast(self.amp, src=0)  # broadcast the tensor from rank 0 to all other ranks (returns None)
         self.amp = bool(self.amp)  # as boolean
         self.scaler = amp.GradScaler(enabled=self.amp)
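The tightened guard matters because `dist.broadcast` is a collective that requires an initialized process group with more than one member: with the old `RANK > -1` check alone, a run that sets `RANK` but spawns a single process (`world_size == 1`) would still call the collective and fail. Below is a minimal standalone sketch of the same rank-0 broadcast pattern, assuming a `torchrun` launch that sets the `RANK` and `WORLD_SIZE` environment variables; `run_amp_check()` is a hypothetical stand-in for `check_amp(self.model)` in the trainer, not part of the patched code.

import os

import torch
import torch.distributed as dist

RANK = int(os.environ.get("RANK", -1))  # -1 when not launched by torchrun
world_size = int(os.environ.get("WORLD_SIZE", 1))

if RANK > -1 and world_size > 1:
    dist.init_process_group(backend="gloo")  # use "nccl" for multi-GPU runs

def run_amp_check() -> bool:
    # Hypothetical stand-in for check_amp(model), which decides whether
    # mixed precision is safe on this hardware.
    return True

# Only rank 0 (or a non-distributed run) performs the check; other ranks
# start with a placeholder value that the broadcast will overwrite.
amp = torch.tensor(run_amp_check() if RANK in (-1, 0) else False)

if RANK > -1 and world_size > 1:  # DDP: every rank adopts rank 0's result
    dist.broadcast(amp, src=0)  # in-place on ranks > 0; returns None

amp = bool(amp)  # tensor -> plain Python bool
scaler = torch.cuda.amp.GradScaler(enabled=amp)  # no-op scaler when amp=False
print(f"rank {RANK}: amp={amp}")

if dist.is_initialized():
    dist.destroy_process_group()

With `world_size == 1` both the `init_process_group` call and the broadcast are skipped entirely, which is exactly what the added `and world_size > 1` condition buys: the single-process path never touches a collective, while multi-process DDP ranks still agree on one AMP setting decided by rank 0.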
|
|
|