aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
authorSamuel Fadel <samuelfadel@gmail.com>2021-07-02 13:22:55 +0200
committerGitHub <noreply@github.com>2021-07-02 13:22:55 +0200
commit72ac3d3333c3dc1d95eacdedbdb5a0132958973a (patch)
tree402dd92d2b1cf67aa7909c4eb87339b1a0acfc4c /torch_ema/ema.py
parenta8223abaad1da1293f350d80b636a8d67b2d58a5 (diff)
parente668ae1e0a757cf8217e926be9ae228676fbe17b (diff)
Merge pull request #7 from Linux-cpp-lisp/device_and_dtype
Device and dtype improvements
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py114
1 files changed, 89 insertions, 25 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..3bcb465 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from typing import Iterable, Optional
import weakref
import copy
+import contextlib
import torch
@@ -34,7 +35,7 @@ class ExponentialMovingAverage:
parameters = list(parameters)
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
- self.collected_params = []
+ self.collected_params = None
# By maintaining only a weakref to each parameter,
# we maintain the old GC behaviour of ExponentialMovingAverage:
# if the model goes out of scope but the ExponentialMovingAverage
@@ -59,6 +60,13 @@ class ExponentialMovingAverage:
)
return parameters
else:
+ parameters = list(parameters)
+ if len(parameters) != len(self.shadow_params):
+ raise ValueError(
+ "Number of parameters passed as argument is different "
+ "from number of shadow parameters maintained by this "
+ "ExponentialMovingAverage"
+ )
return parameters
def update(
@@ -72,10 +80,10 @@ class ExponentialMovingAverage:
the `optimizer.step()` call.
Args:
- parameters: Iterable of `torch.nn.Parameter`; usually the same set of
- parameters used to initialize this object. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+ parameters used to initialize this object. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
parameters = self._get_parameters(parameters)
decay = self.decay
@@ -99,13 +107,13 @@ class ExponentialMovingAverage:
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
- Copy current parameters into given collection of parameters.
+ Copy current averaged parameters into given collection of parameters.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored moving averages. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
@@ -120,14 +128,16 @@ class ExponentialMovingAverage:
Save the current parameters for restoring later.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored. If `None`, the parameters of with which this
- `ExponentialMovingAverage` was initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored. If `None`, the parameters of with which this
+ `ExponentialMovingAverage` was initialized will be used.
"""
parameters = self._get_parameters(parameters)
- self.collected_params = [param.clone()
- for param in parameters
- if param.requires_grad]
+ self.collected_params = [
+ param.clone()
+ for param in parameters
+ if param.requires_grad
+ ]
def restore(
self,
@@ -141,16 +151,43 @@ class ExponentialMovingAverage:
restore the former parameters.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
+ if self.collected_params is None:
+ raise RuntimeError(
+ "This ExponentialMovingAverage has no `store()`ed weights "
+ "to `restore()`"
+ )
parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
if param.requires_grad:
param.data.copy_(c_param.data)
+ def to(self, device=None, dtype=None) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+ Args:
+ device: like `device` argument to `torch.Tensor.to`
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.shadow_params
+ ]
+ if self.collected_params is not None:
+ self.collected_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.collected_params
+ ]
+ return
+
def state_dict(self) -> dict:
r"""Returns the state of the ExponentialMovingAverage as a dict."""
# Following PyTorch conventions, references to tensors are returned:
@@ -178,15 +215,42 @@ class ExponentialMovingAverage:
self.num_updates = state_dict["num_updates"]
assert self.num_updates is None or isinstance(self.num_updates, int), \
"Invalid num_updates"
+
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), \
"shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
+
self.collected_params = state_dict["collected_params"]
- assert isinstance(self.collected_params, list), \
- "collected_params must be a list"
- assert all(
- isinstance(p, torch.Tensor) for p in self.collected_params
- ), "collected_params must all be Tensors"
+ if self.collected_params is not None:
+ assert isinstance(self.collected_params, list), \
+ "collected_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.collected_params
+ ), "collected_params must all be Tensors"
+ assert len(self.collected_params) == len(self.shadow_params), \
+ "collected_params and shadow_params had different lengths"
+
+ if len(self.shadow_params) == len(self._params_refs):
+ # Consistant with torch.optim.Optimizer, cast things to consistant
+ # device and dtype with the parameters
+ params = [p() for p in self._params_refs]
+ # If parameters have been garbage collected, just load the state
+ # we were given without change.
+ if not any(p is None for p in params):
+ # ^ parameter references are still good
+ for i, p in enumerate(params):
+ self.shadow_params[i] = self.shadow_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ if self.collected_params is not None:
+ self.collected_params[i] = self.collected_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ else:
+ raise ValueError(
+ "Tried to `load_state_dict()` with the wrong number of "
+ "parameters in the saved state."
+ )