From 205c76e4789709094ab7bb658b02f9abb789a66a Mon Sep 17 00:00:00 2001 From: Zehui Lin Date: Thu, 21 Jan 2021 09:52:29 +0800 Subject: store and restore --- torch_ema/ema.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 20c62c4..b819a04 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -60,16 +60,30 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(s_param.data) + def store(self, parameters): + """ + Save the current parameters for restore. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored for later restoration. + """ + self.collected_parameters = [] + for param in parameters: + self.collected_parameters.append(param.clone()) + def restore(self, parameters): """ - Restore the parameters given to the copy_to function. + Restore the parameters from the `store` function. Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process. + Store the parameters before the `copy_to` function. + After the validation (or model saving), restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored original parameters. + updated with the previously stored parameters. - for s_param, param in zip(self.collected_parameters, parameters): + for c_param, param in zip(self.collected_parameters, parameters): if param.requires_grad: - param.data.copy_(s_param.data) + param.data.copy_(c_param.data) -- cgit v1.2.3