author      Samuel Fadel <samuelfadel@gmail.com>        2021-02-08 08:07:57 -0300
committer   GitHub <noreply@github.com>                 2021-02-08 08:07:57 -0300
commit      b1de75dfeb7279bcee5e0450665566bafa0b6649 (patch)
tree        39f6a13c973df60c1192e11eb363488955446388 /torch_ema/ema.py
parent      27afe25b9fb9f0d05a87ae94e4e4ad9e92d70a85 (diff)
parent      359c155723f0b08f6e6c1784c7951e33d92b306d (diff)
Merge pull request #2 from Zehui-Lin/master
Add feature: store/restore
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--   torch_ema/ema.py   29
1 file changed, 29 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 4e9dd99..50f9aea 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -20,6 +20,7 @@ class ExponentialMovingAverage:
         """
         if decay < 0.0 or decay > 1.0:
             raise ValueError('Decay must be between 0 and 1')
+        self.collected_parameters = []
         self.decay = decay
         self.num_updates = 0 if use_num_updates else None
         self.shadow_params = [p.clone().detach()
@@ -57,3 +58,31 @@ class ExponentialMovingAverage:
         for s_param, param in zip(self.shadow_params, parameters):
             if param.requires_grad:
                 param.data.copy_(s_param.data)
+
+    def store(self, parameters):
+        """
+        Save the current parameters for restoring later.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored.
+        """
+        self.collected_parameters = []
+        for param in parameters:
+            self.collected_parameters.append(param.clone())
+
+    def restore(self, parameters):
+        """
+        Restore the parameters saved with the `store` method. Useful for
+        validating the model with the EMA parameters without affecting the
+        original optimization process: store the parameters before calling
+        `copy_to`, then restore them after validation (or model saving).
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters.
+        """
+        for c_param, param in zip(self.collected_parameters, parameters):
+            if param.requires_grad:
+                param.data.copy_(c_param.data)
+
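For reference, a minimal usage sketch of the workflow described in the new docstrings (store the training parameters, copy the EMA parameters in, validate, then restore). Only `store` and `restore` appear in the diff above; the constructor arguments and the `update`/`copy_to` calls are assumed from this version of `ExponentialMovingAverage`, and the model, optimizer, and training loop are hypothetical:

    import torch
    from torch_ema import ExponentialMovingAverage

    # Hypothetical toy model and optimizer, for illustration only.
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    # Training: update the shadow (EMA) parameters after each optimizer step.
    for _ in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 10)).pow(2).mean()
        loss.backward()
        optimizer.step()
        ema.update(model.parameters())

    # Validation with EMA weights, without disturbing the ongoing optimization:
    ema.store(model.parameters())    # save the current training parameters
    ema.copy_to(model.parameters())  # load the EMA parameters into the model
    # ... run validation or save a checkpoint here ...
    ema.restore(model.parameters())  # put the training parameters back

This keeps validation or checkpointing with the averaged weights from clobbering the parameters the optimizer is actively updating, which is the use case the `restore` docstring describes.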