|
|
|
@ -209,6 +209,13 @@ class BaseTrainer: |
|
|
|
|
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf'] |
|
|
|
|
else: |
|
|
|
|
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear |
|
|
|
|
# NOTE for testing |
|
|
|
|
# lambda x: max(1 - x / epochs, 0) * (1.0 - lrf) + lrf |
|
|
|
|
self.lf = [self.lf(n) for n in range(self.epochs)] |
|
|
|
|
|
|
|
|
|
# NOTE need to figure out how to make this work, especially for stagnating loss |
|
|
|
|
nudge_lr(self.lf, ...) |
|
|
|
|
|
|
|
|
|
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf) |
|
|
|
|
|
|
|
|
|
def _setup_ddp(self, world_size): |
|
|
|
@ -772,3 +779,30 @@ class BaseTrainer: |
|
|
|
|
f'{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)' |
|
|
|
|
) |
|
|
|
|
return optimizer |
|
|
|
|
|
|
|
|
|
def nudge_lr(series: list[float], start: int, end: int, amplitude: float) -> list[float]:
    """
    Add a symmetric, smooth 'nudge' (bump) to a learning-rate series to help overcome stagnating loss.

    The bump rises from zero at ``start``, peaks around the midpoint of ``[start, end)`` and falls back to zero at
    ``end - 1``. Its profile is ``amplitude * (1 - exp(-2 * (k / span) ** 2))`` where ``k`` is the distance to the
    nearest edge of the range and ``span = end - start``. Because ``k`` never exceeds half the span, the peak stays
    below ``amplitude * (1 - exp(-0.5))``, i.e. roughly ``0.39 * amplitude``.

    Args:
        series (list[float]): The input series of values. Any iterable of floats is accepted; it is never modified.
        start (int): The starting index (inclusive) of the range to nudge.
        end (int): The ending index (exclusive) of the range to nudge. Indices follow Python slice semantics, so a
            range reaching past the end of the series is clamped to the series.
        amplitude (float): Scale factor of the bump added to the series (see the note on the peak height above).

    Returns:
        list[float]: A new list with the nudge applied. If the range is empty or inverted (``end <= start``), an
            unmodified copy of the series is returned.
    """
    nudged = list(series)  # single shallow copy; also lets tuples and other iterables through
    span = end - start
    if span <= 0:
        # Nothing to nudge. Returning early also avoids indexing an empty bump table when an
        # inverted range with mixed-sign indices happens to resolve to a non-empty slice.
        return nudged

    mid = start + span // 2
    rising, falling = slice(start, mid), slice(mid, end)

    # Only offsets 0 .. (end - mid - 1) are ever looked up, so build just that part of the profile.
    half_len = end - mid
    ramp = [amplitude * (1 - math.exp(-2 * (k / span) ** 2)) for k in range(half_len)]

    # Rising half walks the ramp forwards; falling half walks it backwards so the bump is symmetric.
    nudged[rising] = [value + ramp[i] for i, value in enumerate(nudged[rising])]
    nudged[falling] = [value + ramp[half_len - 1 - i] for i, value in enumerate(nudged[falling])]

    return nudged
|
|
|
|