aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--torch_ema/ema.py84
1 files changed, 49 insertions, 35 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index f6d8f6e..b3487cf 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -34,7 +34,7 @@ class ExponentialMovingAverage:
parameters = list(parameters)
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
- self.collected_params = []
+ self.collected_params = None
# By maintaining only a weakref to each parameter,
# we maintain the old GC behaviour of ExponentialMovingAverage:
# if the model goes out of scope but the ExponentialMovingAverage
@@ -59,6 +59,13 @@ class ExponentialMovingAverage:
)
return parameters
else:
+ parameters = list(parameters)
+ if len(parameters) != len(self.shadow_params):
+ raise ValueError(
+ "Number of parameters passed as argument is different "
+ "from number of shadow parameters maintained by this "
+ "ExponentialMovingAverage"
+ )
return parameters
def update(
@@ -99,7 +106,7 @@ class ExponentialMovingAverage:
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
- Copy current parameters into given collection of parameters.
+ Copy current averaged parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
@@ -125,9 +132,11 @@ class ExponentialMovingAverage:
`ExponentialMovingAverage` was initialized will be used.
"""
parameters = self._get_parameters(parameters)
- self.collected_params = [param.clone()
- for param in parameters
- if param.requires_grad]
+ self.collected_params = [
+ param.clone()
+ for param in parameters
+ if param.requires_grad
+ ]
def restore(
self,
@@ -146,6 +155,11 @@ class ExponentialMovingAverage:
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
+ if self.collected_params is None:
+ raise RuntimeError(
+ "This ExponentialMovingAverage has no `store()`ed weights "
+ "to `restore()`"
+ )
parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
if param.requires_grad:
@@ -179,41 +193,41 @@ class ExponentialMovingAverage:
assert self.num_updates is None or isinstance(self.num_updates, int), \
"Invalid num_updates"
- # Consistant with torch.optim.Optimizer, cast things to current
- # device and dtype
- if len(self.shadow_params) > 0:
- device = self.shadow_params[0].device
- dtype = self.shadow_params[0].dtype
- else:
- device = None
- dtype = None
-
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), \
"shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
- # Cast shadow params:
- if device is not None:
- self.shadow_params = [
- p.to(device=device, dtype=dtype)
- if p.is_floating_point()
- else p.to(device=device)
- for p in self.shadow_params
- ]
self.collected_params = state_dict["collected_params"]
- assert isinstance(self.collected_params, list), \
- "collected_params must be a list"
- assert all(
- isinstance(p, torch.Tensor) for p in self.collected_params
- ), "collected_params must all be Tensors"
- # Cast collected params:
- if device is not None:
- self.collected_params = [
- p.to(device=device, dtype=dtype)
- if p.is_floating_point()
- else p.to(device=device)
- for p in self.collected_params
- ]
+ if self.collected_params is not None:
+ assert isinstance(self.collected_params, list), \
+ "collected_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.collected_params
+ ), "collected_params must all be Tensors"
+ assert len(self.collected_params) == len(self.shadow_params), \
+ "collected_params and shadow_params had different lengths"
+
+ if len(self.shadow_params) == len(self._params_refs):
+ # Consistant with torch.optim.Optimizer, cast things to consistant
+ # device and dtype with the parameters
+ params = [p() for p in self._params_refs]
+ # If parameters have been garbage collected, just load the state
+ # we were given without change.
+ if not any(p is None for p in params):
+ # ^ parameter references are still good
+ for i, p in enumerate(params):
+ self.shadow_params[i] = self.shadow_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ if self.collected_params is not None:
+ self.collected_params[i] = self.collected_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ else:
+ raise ValueError(
+ "Tried to `load_state_dict()` with the wrong number of "
+ "parameters in the saved state."
+ )