diff options
author | zehui-lin <zehui-lin_t@outlook.com> | 2021-01-15 11:29:57 +0800 |
---|---|---|
committer | zehui-lin <zehui-lin_t@outlook.com> | 2021-01-15 11:29:57 +0800 |
commit | 448588ac2401e9b68d5fb3f660d38d00cf633aa1 (patch) | |
tree | 1d7ae674d7c92c9fe2981dedb6ae30b4cd835736 /torch_ema/ema.py | |
parent | 27afe25b9fb9f0d05a87ae94e4e4ad9e92d70a85 (diff) |
Add Feature: restore
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- | torch_ema/ema.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 4e9dd99..20c62c4 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -54,6 +54,22 @@ class ExponentialMovingAverage: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. """ + self.collected_parameters = [] for s_param, param in zip(self.shadow_params, parameters): + self.collected_parameters.append(param.clone()) if param.requires_grad: param.data.copy_(s_param.data) + + def restore(self, parameters): + """ + Restore the parameters given to the copy_to function. + Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored original parameters. + """ + for s_param, param in zip(self.collected_parameters, parameters): + if param.requires_grad: + param.data.copy_(s_param.data) + |