author    Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>    2021-04-21 13:42:20 -0600
committer Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>    2021-04-21 13:42:20 -0600
commit    81120309acff37307b2226fbda12277ca1662f93 (patch)
tree      b2c9b1693bda9275e141b38f06052eafb91e0e0b
parent    81a99ed1ec6f576d6b8004c7000ca0bc023e7483 (diff)
Add .to()
-rw-r--r--  tests/test_ema.py   17
-rw-r--r--  torch_ema/ema.py    22
2 files changed, 39 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index aa43b14..edcea4c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -134,3 +134,20 @@ def test_explicit_params():
     ema.update(model2.parameters())
     ema.copy_to()
     assert not torch.all(model.weight == 0.0)
+
+
+def test_to():
+    m = torch.nn.Linear(11, 3)
+    ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+    assert ema.shadow_params[0].dtype == torch.get_default_dtype()
+    ema.to(dtype=torch.float16)
+    assert ema.shadow_params[0].dtype == torch.float16
+    ema.store()
+    # we store whatever we get
+    assert ema.collected_params[0].dtype == torch.get_default_dtype()
+    m = m.to(torch.float16)
+    ema.store(m.parameters())
+    assert ema.collected_params[0].dtype == torch.float16
+    ema.to(dtype=torch.float64)
+    assert ema.collected_params[0].dtype == torch.float64
+    assert ema.shadow_params[0].dtype == torch.float64
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index b3487cf..2aa3004 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -165,6 +165,28 @@ class ExponentialMovingAverage:
             if param.requires_grad:
                 param.data.copy_(c_param.data)
 
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype)
+            if p.is_floating_point()
+            else p.to(device=device)
+            for p in self.shadow_params
+        ]
+        if self.collected_params is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
+        return
+
     def state_dict(self) -> dict:
         r"""Returns the state of the ExponentialMovingAverage as a dict."""
         # Following PyTorch conventions, references to tensors are returned:
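A short usage sketch of the new `.to()` method (not part of the commit itself): the model, decay value, and the CUDA/float16 target are illustrative assumptions, but the calls shown are the package's public API. Floating-point shadow (and collected) buffers are cast to the requested dtype, while any non-floating buffers would only change device.

import torch
from torch_ema import ExponentialMovingAverage

# Illustrative model and decay; any nn.Module's parameters work the same way.
model = torch.nn.Linear(11, 3)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

# Move the EMA's internal buffers alongside the model,
# e.g. to GPU half precision when available.
if torch.cuda.is_available():
    model = model.to(device="cuda", dtype=torch.float16)
    ema.to(device="cuda", dtype=torch.float16)

# After each optimizer step, update the EMA as usual.
ema.update(model.parameters())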