[Feat] Support ReduceOnPlateau (#107)

own
Lin Manhui 3 years ago committed by GitHub
parent de61f6007f
commit 3af03678d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      paddlers/tasks/base.py

@ -374,6 +374,11 @@ class BaseModel(metaclass=ModelMeta):
lr = self.optimizer.get_lr() lr = self.optimizer.get_lr()
if isinstance(self.optimizer._learning_rate, if isinstance(self.optimizer._learning_rate,
paddle.optimizer.lr.LRScheduler): paddle.optimizer.lr.LRScheduler):
# If ReduceOnPlateau is used as the scheduler, use the loss value as the metric.
if isinstance(self.optimizer._learning_rate,
paddle.optimizer.lr.ReduceOnPlateau):
self.optimizer._learning_rate.step(loss.item())
else:
self.optimizer._learning_rate.step() self.optimizer._learning_rate.step()
train_avg_metrics.update(outputs) train_avg_metrics.update(outputs)

Loading…
Cancel
Save