author     Samuel Fadel <samuelfadel@gmail.com>  2021-02-08 08:07:57 -0300
committer  GitHub <noreply@github.com>           2021-02-08 08:07:57 -0300
commit     b1de75dfeb7279bcee5e0450665566bafa0b6649 (patch)
tree       39f6a13c973df60c1192e11eb363488955446388 /torch_ema
parent     27afe25b9fb9f0d05a87ae94e4e4ad9e92d70a85 (diff)
parent     359c155723f0b08f6e6c1784c7951e33d92b306d (diff)
Merge pull request #2 from Zehui-Lin/master
Add feature: store/restore
Diffstat (limited to 'torch_ema')
-rw-r--r--  torch_ema/ema.py  29
1 file changed, 29 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 4e9dd99..50f9aea 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -20,6 +20,7 @@ class ExponentialMovingAverage:
"""
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
+ self.collected_parameters = []
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.shadow_params = [p.clone().detach()
@@ -57,3 +58,31 @@ class ExponentialMovingAverage:
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
                 param.data.copy_(s_param.data)
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored.
+        """
+        self.collected_parameters = []
+        for param in parameters:
+            self.collected_parameters.append(param.clone())
+
+    def restore(self, parameters):
+        """
+        Restore the parameters stored with the `store` method. Useful for
+        validating the model with EMA parameters without affecting the
+        original optimization process: call `store` before `copy_to`, and
+        after validation (or model saving) restore the former parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_parameters, parameters):
+            if param.requires_grad:
+                param.data.copy_(c_param.data)
+
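For context, a minimal sketch of how the store/restore workflow added in this merge is typically used together with the existing `update` and `copy_to` methods. The constructor signature `ExponentialMovingAverage(parameters, decay)`, the model, the optimizer, and the training loop below are illustrative assumptions, not part of this diff.

import torch
from torch_ema import ExponentialMovingAverage

# Hypothetical model and optimizer, purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

for _ in range(100):
    x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model.parameters())  # keep the shadow parameters up to date

# Validate with EMA weights without disturbing the training parameters:
ema.store(model.parameters())    # save the current (training) parameters
ema.copy_to(model.parameters())  # load the EMA parameters into the model
with torch.no_grad():
    _ = model(torch.randn(32, 10))  # run validation / save a checkpoint here
ema.restore(model.parameters())  # put the original parameters back

The point of the new methods is the last block: without `store`/`restore`, evaluating with EMA weights via `copy_to` would overwrite the optimizer's working parameters with no way to get them back.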