Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- | torch_ema/ema.py | 114
1 file changed, 89 insertions, 25 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..3bcb465 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
 from typing import Iterable, Optional
 import weakref
 import copy
+import contextlib
 
 import torch
 
@@ -34,7 +35,7 @@ class ExponentialMovingAverage:
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach()
                               for p in parameters if p.requires_grad]
-        self.collected_params = []
+        self.collected_params = None
         # By maintaining only a weakref to each parameter,
         # we maintain the old GC behaviour of ExponentialMovingAverage:
         # if the model goes out of scope but the ExponentialMovingAverage
@@ -59,6 +60,13 @@ class ExponentialMovingAverage:
                 )
             return parameters
         else:
+            parameters = list(parameters)
+            if len(parameters) != len(self.shadow_params):
+                raise ValueError(
+                    "Number of parameters passed as argument is different "
+                    "from number of shadow parameters maintained by this "
+                    "ExponentialMovingAverage"
+                )
             return parameters
 
     def update(
         self,
@@ -72,10 +80,10 @@ class ExponentialMovingAverage:
         the `optimizer.step()` call.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
-            parameters used to initialize this object. If `None`, the
-            parameters with which this `ExponentialMovingAverage` was
-            initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+                parameters used to initialize this object. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
         """
         parameters = self._get_parameters(parameters)
         decay = self.decay
@@ -99,13 +107,13 @@ class ExponentialMovingAverage:
         parameters: Optional[Iterable[torch.nn.Parameter]] = None
     ) -> None:
         """
-        Copy current parameters into given collection of parameters.
+        Copy current averaged parameters into given collection of parameters.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            updated with the stored moving averages. If `None`, the
-            parameters with which this `ExponentialMovingAverage` was
-            initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
         """
         parameters = self._get_parameters(parameters)
         for s_param, param in zip(self.shadow_params, parameters):
@@ -120,14 +128,16 @@ class ExponentialMovingAverage:
         Save the current parameters for restoring later.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            temporarily stored. If `None`, the parameters of with which this
-            `ExponentialMovingAverage` was initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored. If `None`, the parameters of with which this
+                `ExponentialMovingAverage` was initialized will be used.
         """
         parameters = self._get_parameters(parameters)
-        self.collected_params = [param.clone()
-                                 for param in parameters
-                                 if param.requires_grad]
+        self.collected_params = [
+            param.clone()
+            for param in parameters
+            if param.requires_grad
+        ]
 
     def restore(
         self,
@@ -141,16 +151,43 @@ class ExponentialMovingAverage:
         restore the former parameters.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            updated with the stored parameters. If `None`, the
-            parameters with which this `ExponentialMovingAverage` was
-            initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
         """
+        if self.collected_params is None:
+            raise RuntimeError(
+                "This ExponentialMovingAverage has no `store()`ed weights "
+                "to `restore()`"
+            )
         parameters = self._get_parameters(parameters)
         for c_param, param in zip(self.collected_params, parameters):
             if param.requires_grad:
                 param.data.copy_(c_param.data)
 
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype)
+            if p.is_floating_point()
+            else p.to(device=device)
+            for p in self.shadow_params
+        ]
+        if self.collected_params is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
+        return
+
     def state_dict(self) -> dict:
         r"""Returns the state of the ExponentialMovingAverage as a dict."""
         # Following PyTorch conventions, references to tensors are returned:
@@ -178,15 +215,42 @@ class ExponentialMovingAverage:
         self.num_updates = state_dict["num_updates"]
         assert self.num_updates is None or isinstance(self.num_updates, int), \
             "Invalid num_updates"
+
         self.shadow_params = state_dict["shadow_params"]
         assert isinstance(self.shadow_params, list), \
             "shadow_params must be a list"
         assert all(
             isinstance(p, torch.Tensor) for p in self.shadow_params
         ), "shadow_params must all be Tensors"
+
         self.collected_params = state_dict["collected_params"]
-        assert isinstance(self.collected_params, list), \
-            "collected_params must be a list"
-        assert all(
-            isinstance(p, torch.Tensor) for p in self.collected_params
-        ), "collected_params must all be Tensors"
+        if self.collected_params is not None:
+            assert isinstance(self.collected_params, list), \
+                "collected_params must be a list"
+            assert all(
+                isinstance(p, torch.Tensor) for p in self.collected_params
+            ), "collected_params must all be Tensors"
+            assert len(self.collected_params) == len(self.shadow_params), \
+                "collected_params and shadow_params had different lengths"
+
+        if len(self.shadow_params) == len(self._params_refs):
+            # Consistant with torch.optim.Optimizer, cast things to consistant
+            # device and dtype with the parameters
+            params = [p() for p in self._params_refs]
+            # If parameters have been garbage collected, just load the state
+            # we were given without change.
+            if not any(p is None for p in params):
+                # ^ parameter references are still good
+                for i, p in enumerate(params):
+                    self.shadow_params[i] = self.shadow_params[i].to(
+                        device=p.device, dtype=p.dtype
+                    )
+                    if self.collected_params is not None:
+                        self.collected_params[i] = self.collected_params[i].to(
+                            device=p.device, dtype=p.dtype
+                        )
+        else:
+            raise ValueError(
+                "Tried to `load_state_dict()` with the wrong number of "
+                "parameters in the saved state."
+            )
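The diff above changes several user-facing behaviours: collected_params starts as None and restore() now raises RuntimeError when store() was never called, a new to() method moves the shadow (and any stored) tensors across devices and dtypes, and load_state_dict() casts the loaded tensors to the device and dtype of the tracked parameters. The sketch below shows how these pieces might be exercised together. It is a minimal illustration, not part of the repository: the model, optimizer, and training loop are made up, and the import path and constructor arguments (parameters, decay) are assumed to match the upstream torch_ema package.

import torch
from torch_ema import ExponentialMovingAverage

# Hypothetical model/optimizer purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

for _ in range(3):  # stand-in for a real training loop
    x, y = torch.randn(8, 10), torch.randn(8, 2)
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()  # parameters default to those given at construction

# Evaluate with the averaged weights, then put the originals back.
# After this commit, calling restore() without a prior store() raises
# RuntimeError instead of silently iterating over an empty list.
ema.store()
ema.copy_to()
with torch.no_grad():
    model(torch.randn(8, 10))
ema.restore()

# New to() method: move shadow_params (and collected_params, if any)
# along with the model.
if torch.cuda.is_available():
    model.cuda()
    ema.to(device=torch.device("cuda"))

# Checkpoint round trip: load_state_dict() now also casts the loaded
# tensors to the device/dtype of the tracked parameters, and
# collected_params may legitimately be None in the saved state.
state = ema.state_dict()
ema.load_state_dict(state)

Making collected_params default to None rather than an empty list is what lets restore() distinguish "nothing stored yet" from an actual stored snapshot, and it is why load_state_dict() now accepts None for that entry.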