diff options
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- | torch_ema/ema.py | 73 |
1 files changed, 62 insertions, 11 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 0233c78..2e8eb6f 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -1,7 +1,8 @@ from __future__ import division from __future__ import unicode_literals -from typing import Iterable +from typing import Iterable, Optional +import weakref import torch @@ -13,8 +14,8 @@ class ExponentialMovingAverage: Maintains (exponential) moving average of a set of parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; usually the result of - `model.parameters()`. + parameters: Iterable of `torch.nn.Parameter` (typically from + `model.parameters()`). decay: The exponential decay. use_num_updates: Whether to use number of updates when computing averages. @@ -29,11 +30,40 @@ class ExponentialMovingAverage: raise ValueError('Decay must be between 0 and 1') self.decay = decay self.num_updates = 0 if use_num_updates else None + parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] self.collected_params = [] + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. + self._params_refs = [weakref.ref(p) for p in parameters] - def update(self, parameters: Iterable[torch.nn.Parameter]) -> None: + def _get_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this " + "ExponentialMovingAverage " + "was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep " + "the model to which they belong from being garbage " + "collected." 
+ ) + return parameters + else: + return parameters + + def update( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Update currently maintained parameters. @@ -42,8 +72,11 @@ class ExponentialMovingAverage: Args: parameters: Iterable of `torch.nn.Parameter`; usually the same set of - parameters used to initialize this object. + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ + parameters = self._get_parameters(parameters) decay = self.decay if self.num_updates is not None: self.num_updates += 1 @@ -60,31 +93,46 @@ class ExponentialMovingAverage: tmp.mul_(one_minus_decay) s_param.sub_(tmp) - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + def copy_to( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Copy current parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ + parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: param.data.copy_(s_param.data) - def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + def store( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. + temporarily stored. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. 
""" + parameters = self._get_parameters(parameters) self.collected_params = [param.clone() for param in parameters if param.requires_grad] - def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + + def restore( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the @@ -94,8 +142,11 @@ class ExponentialMovingAverage: Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ + parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: param.data.copy_(c_param.data) |