From 448588ac2401e9b68d5fb3f660d38d00cf633aa1 Mon Sep 17 00:00:00 2001 From: zehui-lin Date: Fri, 15 Jan 2021 11:29:57 +0800 Subject: Add Feature: restore --- torch_ema/ema.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'torch_ema') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 4e9dd99..20c62c4 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -54,6 +54,22 @@ class ExponentialMovingAverage: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. """ + self.collected_parameters = [] for s_param, param in zip(self.shadow_params, parameters): + self.collected_parameters.append(param.clone()) if param.requires_grad: param.data.copy_(s_param.data) + + def restore(self, parameters): + """ + Restore the parameters given to the copy_to function. + Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored original parameters. + """ + for s_param, param in zip(self.collected_parameters, parameters): + if param.requires_grad: + param.data.copy_(s_param.data) + -- cgit v1.2.3 From 205c76e4789709094ab7bb658b02f9abb789a66a Mon Sep 17 00:00:00 2001 From: Zehui Lin Date: Thu, 21 Jan 2021 09:52:29 +0800 Subject: store and restore --- torch_ema/ema.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) (limited to 'torch_ema') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 20c62c4..b819a04 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -60,16 +60,30 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(s_param.data) + def store(self, parameters): + """ + Save the current parameters for restore. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. + """ + self.collected_parameters = [] + for param in parameters: + self.collected_parameters.append(param.clone()) + def restore(self, parameters): """ - Restore the parameters given to the copy_to function. + Restore the parameters from the `store` function. Usually used in validation. Want to validate the model with EMA parameters without affecting the original optimization process. + Store the parameters before the `copy_to` function. + After the validation(or model saving), restore the former parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored original parameters. + updated with the stored moving averages. """ - for s_param, param in zip(self.collected_parameters, parameters): + for c_param, param in zip(self.collected_parameters, parameters): if param.requires_grad: - param.data.copy_(s_param.data) + param.data.copy_(c_param.data) -- cgit v1.2.3 From 0ed9f9ddc9b93bbcd6d3828803f20a23cc500f5e Mon Sep 17 00:00:00 2001 From: Zehui Lin Date: Thu, 21 Jan 2021 09:55:41 +0800 Subject: fix comments --- torch_ema/ema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'torch_ema') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index b819a04..8f39bc5 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -66,7 +66,7 @@ class ExponentialMovingAverage: Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. + temporary stored in. """ self.collected_parameters = [] for param in parameters: @@ -81,7 +81,7 @@ class ExponentialMovingAverage: Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. + updated with the stored parameters. """ for c_param, param in zip(self.collected_parameters, parameters): if param.requires_grad: -- cgit v1.2.3 From 359c155723f0b08f6e6c1784c7951e33d92b306d Mon Sep 17 00:00:00 2001 From: Zehui Lin Date: Mon, 8 Feb 2021 10:09:36 +0800 Subject: minor --- torch_ema/ema.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'torch_ema') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 8f39bc5..50f9aea 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -20,6 +20,7 @@ class ExponentialMovingAverage: """ if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') + self.collected_parameters = [] self.decay = decay self.num_updates = 0 if use_num_updates else None self.shadow_params = [p.clone().detach() @@ -54,9 +55,7 @@ class ExponentialMovingAverage: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. """ - self.collected_parameters = [] for s_param, param in zip(self.shadow_params, parameters): - self.collected_parameters.append(param.clone()) if param.requires_grad: param.data.copy_(s_param.data) -- cgit v1.2.3