diff options
-rw-r--r-- | torch_ema/ema.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 20c62c4..b819a04 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -60,16 +60,30 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(s_param.data) + def store(self, parameters): + """ + Save the current parameters for restore. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. + """ + self.collected_parameters = [] + for param in parameters: + self.collected_parameters.append(param.clone()) + def restore(self, parameters): """ - Restore the parameters given to the copy_to function. + Restore the parameters from the `store` function. Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process. + Store the parameters before the `copy_to` function. + After the validation(or model saving), restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored original parameters. + updated with the stored moving averages. """ - for s_param, param in zip(self.collected_parameters, parameters): + for c_param, param in zip(self.collected_parameters, parameters): if param.requires_grad: - param.data.copy_(s_param.data) + param.data.copy_(c_param.data) |