from __future__ import division
from __future__ import unicode_literals

import torch


# Partially based on:
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
    """
    Maintains an exponential moving average of a set of parameters.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
          decay: The exponential decay rate.
          use_num_updates: Whether to use the number of updates when computing
            averages.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        # Shadow copies of the parameters that hold the running averages.
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]

    def update(self, parameters):
        """
        Update the currently maintained moving averages.

        Call this every time the parameters are updated, e.g. right after the
        `optimizer.step()` call.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            # Use a smaller effective decay early in training
            # (TensorFlow-style num_updates warm-up schedule).
            self.num_updates += 1
            decay = min(decay,
                        (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            parameters = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, parameters):
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy the stored moving averages into the given collection of
        parameters, keeping a copy of the current parameter values in
        `self.collected_parameters`.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
        """
        self.collected_parameters = []
        for s_param, param in zip(self.shadow_params, parameters):
            self.collected_parameters.append(param.clone())
            if param.requires_grad:
                param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters so they can be restored later.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_parameters = []
        for param in parameters:
            self.collected_parameters.append(param.clone())

    def restore(self, parameters):
        """
        Restore the parameters saved with the `store` method.

        Useful for validation: store the parameters before calling `copy_to`,
        validate (or save) the model with the EMA parameters, then call this
        method to restore the former parameters, so the original optimization
        process is not affected.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the previously stored parameters.
        """
        for c_param, param in zip(self.collected_parameters, parameters):
            if param.requires_grad:
                param.data.copy_(c_param.data)
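

# A minimal usage sketch, not part of the original module: it assumes a toy
# `nn.Linear` model, plain SGD, and random data purely for illustration of the
# intended update / store / copy_to / restore workflow.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 1)  # stand-in for a real model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

    for _ in range(100):
        x = torch.randn(32, 10)
        y = torch.randn(32, 1)
        loss = ((model(x) - y) ** 2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update the shadow parameters after every optimizer step.
        ema.update(model.parameters())

    # Validate with the averaged weights, then put the training weights back.
    ema.store(model.parameters())    # keep the current (training) weights
    ema.copy_to(model.parameters())  # load the EMA weights into the model
    with torch.no_grad():
        val_loss = ((model(torch.randn(32, 10)) - torch.randn(32, 1)) ** 2).mean()
    ema.restore(model.parameters())  # restore the training weights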