diff options
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- | torch_ema/ema.py | 59 |
1 files changed, 59 insertions, 0 deletions
from __future__ import division
from __future__ import unicode_literals

import torch


# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    Only parameters with `requires_grad=True` are tracked. One detached
    "shadow" copy is stored per tracked parameter, in iteration order, in
    `self.shadow_params`.
    """
    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
          decay: The exponential decay.
          use_num_updates: Whether to use number of updates when computing
            averages.

        Raises:
          ValueError: If `decay` is smaller than 0 or larger than 1.
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        # `None` disables the update-count based warm-up in `update()`.
        self.num_updates = 0 if use_num_updates else None
        # Detached clones, so the averages are independent of the live
        # parameters and of the autograd graph.
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Use a smaller effective decay during the first updates so the
            # average moves faster at the start.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        # `shadow_params` holds entries only for parameters that require
        # gradients, so apply the same filter here before pairing them up.
        # Zipping against the unfiltered iterable would misalign the pairs as
        # soon as one parameter is frozen.
        parameters = [p for p in parameters if p.requires_grad]
        with torch.no_grad():
            for s_param, param in zip(self.shadow_params, parameters):
                # In-place form of: shadow = decay * shadow + (1 - decay) * param
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copies current parameters into given collection of parameters.

        Parameters that do not require gradients are left untouched.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
        """
        # Same filtering as in `__init__` keeps shadows and parameters aligned.
        parameters = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)