 torch_ema/ema.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 50f9aea..447fd1e 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -20,11 +20,11 @@ class ExponentialMovingAverage:
"""
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
- self.collected_parameters = []
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
+ self.collected_params = []
def update(self, parameters):
"""
@@ -49,7 +49,7 @@ class ExponentialMovingAverage:
     def copy_to(self, parameters):
         """
-        Copies current parameters into given collection of parameters.
+        Copy current parameters into given collection of parameters.

         Args:
           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
@@ -61,28 +61,27 @@ class ExponentialMovingAverage:
     def store(self, parameters):
         """
-        Save the current parameters for restore.
+        Save the current parameters for restoring later.

         Args:
           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            temporary stored in.
+            temporarily stored.
         """
-        self.collected_parameters = []
-        for param in parameters:
-            self.collected_parameters.append(param.clone())
+        self.collected_params = [param.clone() for param in parameters]

     def restore(self, parameters):
         """
-        Restore the parameters from the `store` function.
-        Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process.
-        Store the parameters before the `copy_to` function.
-        After the validation(or model saving), restore the former parameters.
+        Restore the parameters stored with the `store` method.
+        Useful to validate the model with EMA parameters without affecting the
+        original optimization process. Store the parameters before the
+        `copy_to` method. After validation (or model saving), use this to
+        restore the former parameters.

         Args:
           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
             updated with the stored parameters.
         """
-        for c_param, param in zip(self.collected_parameters, parameters):
+        for c_param, param in zip(self.collected_params, parameters):
             if param.requires_grad:
                 param.data.copy_(c_param.data)
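
For reference, a minimal usage sketch of the store / copy_to / restore cycle
documented above, assuming the package exports `ExponentialMovingAverage` at
the top level. The `model`, `optimizer`, `train_loader`, and `evaluate` names
are hypothetical placeholders; only the `ExponentialMovingAverage` calls come
from the code in this diff.

    # Hypothetical training/validation loop; model, optimizer, train_loader,
    # and evaluate are placeholders, not part of torch_ema.
    from torch_ema import ExponentialMovingAverage

    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    for batch in train_loader:
        optimizer.zero_grad()
        loss = model(batch).mean()      # placeholder loss
        loss.backward()
        optimizer.step()
        ema.update(model.parameters())  # fold new weights into the average

    # Validate with the averaged weights, then put the trained weights back.
    ema.store(model.parameters())    # stash current weights in collected_params
    ema.copy_to(model.parameters())  # load the EMA weights into the model
    evaluate(model)
    ema.restore(model.parameters())  # bring back the stashed weights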