aboutsummaryrefslogtreecommitdiff
path: root/torch_ema
diff options
context:
space:
mode:
authorZehui Lin <zehui-lin_t@outlook.com>2021-01-21 09:52:29 +0800
committerZehui Lin <zehui-lin_t@outlook.com>2021-01-21 09:52:29 +0800
commit205c76e4789709094ab7bb658b02f9abb789a66a (patch)
treec5052bce1100b02eefbaebdfeb46b917d98a069d /torch_ema
parent448588ac2401e9b68d5fb3f660d38d00cf633aa1 (diff)
store and restore
Diffstat (limited to 'torch_ema')
-rw-r--r--torch_ema/ema.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 20c62c4..b819a04 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -60,16 +60,30 @@ class ExponentialMovingAverage:
if param.requires_grad:
param.data.copy_(s_param.data)
+ def store(self, parameters):
+ """
+ Save the current parameters for restore.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages.
+ """
+ self.collected_parameters = []
+ for param in parameters:
+ self.collected_parameters.append(param.clone())
+
def restore(self, parameters):
"""
- Restore the parameters given to the copy_to function.
+ Restore the parameters from the `store` function.
Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process.
+ Store the parameters before the `copy_to` function.
+ After the validation(or model saving), restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored original parameters.
+ updated with the stored moving averages.
"""
- for s_param, param in zip(self.collected_parameters, parameters):
+ for c_param, param in zip(self.collected_parameters, parameters):
if param.requires_grad:
- param.data.copy_(s_param.data)
+ param.data.copy_(c_param.data)