diff options
-rw-r--r-- | torch_ema/ema.py | 84 |
1 files changed, 49 insertions, 35 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index f6d8f6e..b3487cf 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -34,7 +34,7 @@ class ExponentialMovingAverage: parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] - self.collected_params = [] + self.collected_params = None # By maintaining only a weakref to each parameter, # we maintain the old GC behaviour of ExponentialMovingAverage: # if the model goes out of scope but the ExponentialMovingAverage @@ -59,6 +59,13 @@ class ExponentialMovingAverage: ) return parameters else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) return parameters def update( @@ -99,7 +106,7 @@ class ExponentialMovingAverage: parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ - Copy current parameters into given collection of parameters. + Copy current averaged parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be @@ -125,9 +132,11 @@ class ExponentialMovingAverage: `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) - self.collected_params = [param.clone() - for param in parameters - if param.requires_grad] + self.collected_params = [ + param.clone() + for param in parameters + if param.requires_grad + ] def restore( self, @@ -146,6 +155,11 @@ class ExponentialMovingAverage: parameters with which this `ExponentialMovingAverage` was initialized will be used. """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: @@ -179,41 +193,41 @@ class ExponentialMovingAverage: assert self.num_updates is None or isinstance(self.num_updates, int), \ "Invalid num_updates" - # Consistant with torch.optim.Optimizer, cast things to current - # device and dtype - if len(self.shadow_params) > 0: - device = self.shadow_params[0].device - dtype = self.shadow_params[0].dtype - else: - device = None - dtype = None - self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), \ "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" - # Cast shadow params: - if device is not None: - self.shadow_params = [ - p.to(device=device, dtype=dtype) - if p.is_floating_point() - else p.to(device=device) - for p in self.shadow_params - ] self.collected_params = state_dict["collected_params"] - assert isinstance(self.collected_params, list), \ - "collected_params must be a list" - assert all( - isinstance(p, torch.Tensor) for p in self.collected_params - ), "collected_params must all be Tensors" - # Cast collected params: - if device is not None: - self.collected_params = [ - p.to(device=device, dtype=dtype) - if p.is_floating_point() - else p.to(device=device) - for p in self.collected_params - ] + if self.collected_params is not None: + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len(self.shadow_params), \ + "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." + ) |