|
|
|
@@ -387,21 +387,22 @@ class BaseTrainer:

                # Backward
                self.scaler.scale(self.loss).backward()

                # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
                if ni - last_opt_step >= self.accumulate:
                    self.optimizer_step()
                    last_opt_step = ni

                    # Timed stopping
                    if self.args.time:
                        self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                        if RANK != -1:  # if DDP training
                            broadcast_list = [self.stop if RANK == 0 else None]
                            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                            self.stop = broadcast_list[0]
                        if self.stop:  # training time exceeded
                            break

                # Log
                if RANK in {-1, 0}:
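
The backward/step pairing above follows the gradient-accumulation recipe from the PyTorch AMP examples page linked in the code comment: the scaled loss is backpropagated every batch, but the optimizer only steps once `accumulate` batches' worth of gradients have piled up (in the trainer, `optimizer_step()` wraps the scaler step plus gradient unscaling and clipping). A minimal standalone sketch of that pattern, with a toy model and loader — all names below are illustrative, not part of BaseTrainer:

    import torch

    model = torch.nn.Linear(10, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler()
    loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(8)]
    accumulate = 4  # step the optimizer once every 4 batches

    for ni, (x, y) in enumerate(loader):
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(x.cuda()), y.cuda())
        # Scale the loss so small fp16 gradients don't flush to zero, then accumulate.
        scaler.scale(loss / accumulate).backward()
        if (ni + 1) % accumulate == 0:
            scaler.step(optimizer)  # unscales gradients; skips the step on inf/NaN
            scaler.update()         # adjust the scale factor for the next iteration
            optimizer.zero_grad()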
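
The timed-stopping block uses torch.distributed.broadcast_object_list so that rank 0's stop decision reaches every DDP worker; if each rank decided from its own clock, some could break out of the loop while others hang at the next collective. A hedged sketch of just that synchronization step — the function name and the assumption of an already-initialized process group are mine, not from the trainer:

    import torch.distributed as dist

    def sync_stop_flag(stop: bool) -> bool:
        """Return rank 0's stop decision on every rank (process group must be initialized)."""
        broadcast_list = [stop if dist.get_rank() == 0 else None]  # only rank 0's value is kept
        dist.broadcast_object_list(broadcast_list, src=0)  # collective: every rank must call this
        return broadcast_list[0]

Every rank then breaks (or continues) on the same boolean, which is why the trainer overwrites self.stop with broadcast_list[0] before checking it.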