author    zehui-lin <zehui-lin_t@outlook.com>  2021-01-15 11:29:57 +0800
committer zehui-lin <zehui-lin_t@outlook.com>  2021-01-15 11:29:57 +0800
commit    448588ac2401e9b68d5fb3f660d38d00cf633aa1 (patch)
tree      1d7ae674d7c92c9fe2981dedb6ae30b4cd835736 /torch_ema/ema.py
parent    27afe25b9fb9f0d05a87ae94e4e4ad9e92d70a85 (diff)
Add Feature: restore
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- torch_ema/ema.py | 16
1 file changed, 16 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 4e9dd99..20c62c4 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -54,6 +54,22 @@ class ExponentialMovingAverage:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages.
"""
+ self.collected_parameters = []
for s_param, param in zip(self.shadow_params, parameters):
+ self.collected_parameters.append(param.clone())
if param.requires_grad:
param.data.copy_(s_param.data)
+
+ def restore(self, parameters):
+ """
+ Restore the original parameters stored when `copy_to` was called.
+ Useful for validating the model with the EMA parameters without
+ affecting the original optimization process.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored original parameters.
+ """
+ for s_param, param in zip(self.collected_parameters, parameters):
+ if param.requires_grad:
+ param.data.copy_(s_param.data)
+
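
With this change, `copy_to` snapshots the current parameters into `self.collected_parameters` before overwriting them with the shadow (EMA) values, and the new `restore` method copies that snapshot back. A minimal usage sketch follows; the model, decay value, and training-loop details are hypothetical, and the constructor and `update` signatures are assumed to match this version of torch_ema.

    # Minimal usage sketch (assumptions: ExponentialMovingAverage takes the
    # parameter iterable and a decay value, and update()/copy_to()/restore()
    # all take the parameter iterable, as in this version of torch_ema).
    import torch
    from torch_ema import ExponentialMovingAverage

    model = torch.nn.Linear(10, 2)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    # After each optimizer step, refresh the moving averages.
    # optimizer.step()
    ema.update(model.parameters())

    # Validate with the EMA weights, then put the original weights back.
    ema.copy_to(model.parameters())   # saves current params, loads EMA params
    # ... run validation here ...
    ema.restore(model.parameters())   # restores the params saved by copy_to

The design choice here is that `restore` depends on a preceding `copy_to` call, since that is where the original parameters are collected; calling `restore` first would leave `self.collected_parameters` undefined.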