# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


class EMA(object):
    """
    The implementation of Exponential Moving Average (EMA) for the trainable
    parameters of a model.

    Args:
        model (nn.Layer): The model to apply EMA to.
        decay (float, optional): The decay rate used to update each EMA
            variable by `ema_variable = decay * ema_variable
            + (1 - decay) * new_variable`. Default: 0.99.

    Returns:
        None

    Examples:
        .. code-block:: python

            # 1. Define model and dataset

            # 2. Create EMA
            ema = EMA(model, decay=0.99)

            # 3. Train stage
            for data in dataloader():
                ...
                optimizer.step()
                ema.step()

            # 4. Evaluate stage
            ema.apply()  # Replace the original parameters with the EMA parameters
            for data in dataloader():
                ...
            ema.restore()  # Restore the original parameters to the model
    """

    def __init__(self, model, decay=0.99):
        super().__init__()
        assert isinstance(model, paddle.nn.Layer), \
            "The model should be an instance of paddle.nn.Layer."
        assert 0.0 <= decay <= 1.0, \
            "The decay = {} should be in [0.0, 1.0].".format(decay)
        self._model = model
        self._decay = decay
        self._ema_data = {}
        self._backup_data = {}

        # Initialize the EMA variables with the current parameter values.
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                self._ema_data[name] = param.numpy()

    def step(self):
        """
        Update the EMA data for all trainable parameters.
        """
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                assert name in self._ema_data, \
                    "The param ({}) isn't in the model.".format(name)
                self._ema_data[name] = self._decay * self._ema_data[name] \
                    + (1.0 - self._decay) * param.numpy()

    def apply(self):
        """
        Back up the original parameters and replace them with the EMA data.
        """
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                assert name in self._ema_data, \
                    "The param ({}) isn't in the model.".format(name)
                self._backup_data[name] = param.numpy()
                param.set_value(self._ema_data[name])

    def restore(self):
        """
        Restore the original parameters to the model.
        """
        for name, param in self._model.named_parameters():
            if not param.stop_gradient:
                assert name in self._backup_data, \
                    "The param ({}) hasn't been backed up; call apply() first.".format(name)
                param.set_value(self._backup_data[name])
        self._backup_data = {}
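

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the library API).
# It exercises the step/apply/restore cycle end to end on a toy linear model
# with random data; the model, optimizer, and hyperparameters below are
# placeholder assumptions, and any paddle.nn.Layer works the same way.
if __name__ == "__main__":
    paddle.seed(0)
    model = paddle.nn.Linear(4, 2)
    optimizer = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=model.parameters())
    ema = EMA(model, decay=0.99)

    # Train stage: update the EMA variables after each optimizer step.
    for _ in range(5):
        x = paddle.randn([8, 4])
        loss = model(x).mean()
        loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        ema.step()

    # Evaluate stage: swap in the EMA weights for evaluation, then restore
    # the raw training weights before training continues.
    ema.apply()
    print("eval with EMA weights:", model(paddle.randn([1, 4])).numpy())
    ema.restore()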