diff options
author | Zehui Lin <zehui-lin_t@outlook.com> | 2021-01-21 09:52:29 +0800 |
---|---|---|
committer | Zehui Lin <zehui-lin_t@outlook.com> | 2021-01-21 09:52:29 +0800 |
commit | 205c76e4789709094ab7bb658b02f9abb789a66a (patch) | |
tree | c5052bce1100b02eefbaebdfeb46b917d98a069d | |
parent | 448588ac2401e9b68d5fb3f660d38d00cf633aa1 (diff) |
store and restore
-rw-r--r-- | torch_ema/ema.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 20c62c4..b819a04 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -60,16 +60,30 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(s_param.data) + def store(self, parameters): + """ + Save the current parameters for restore. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. + """ + self.collected_parameters = [] + for param in parameters: + self.collected_parameters.append(param.clone()) + def restore(self, parameters): """ - Restore the parameters given to the copy_to function. + Restore the parameters from the `store` function. Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process. + Store the parameters before the `copy_to` function. + After the validation(or model saving), restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored original parameters. + updated with the stored moving averages. """ - for s_param, param in zip(self.collected_parameters, parameters): + for c_param, param in zip(self.collected_parameters, parameters): if param.requires_grad: - param.data.copy_(s_param.data) + param.data.copy_(c_param.data) |