diff options
-rw-r--r-- | torch_ema/ema.py | 23 |
1 files changed, 11 insertions, 12 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 50f9aea..447fd1e 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -20,11 +20,11 @@ class ExponentialMovingAverage: """ if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') - self.collected_parameters = [] self.decay = decay self.num_updates = 0 if use_num_updates else None self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] + self.collected_params = [] def update(self, parameters): """ @@ -49,7 +49,7 @@ class ExponentialMovingAverage: def copy_to(self, parameters): """ - Copies current parameters into given collection of parameters. + Copy current parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be @@ -61,28 +61,27 @@ class ExponentialMovingAverage: def store(self, parameters): """ - Save the current parameters for restore. + Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporary stored in. + temporarily stored. """ - self.collected_parameters = [] - for param in parameters: - self.collected_parameters.append(param.clone()) + self.collected_params = [param.clone() for param in parameters] def restore(self, parameters): """ - Restore the parameters from the `store` function. - Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process. - Store the parameters before the `copy_to` function. - After the validation(or model saving), restore the former parameters. + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. """ - for c_param, param in zip(self.collected_parameters, parameters): + for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: param.data.copy_(c_param.data) |