From 81120309acff37307b2226fbda12277ca1662f93 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 13:42:20 -0600 Subject: Add .to() --- tests/test_ema.py | 17 +++++++++++++++++ torch_ema/ema.py | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/tests/test_ema.py b/tests/test_ema.py index aa43b14..edcea4c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -134,3 +134,20 @@ def test_explicit_params(): ema.update(model2.parameters()) ema.copy_to() assert not torch.all(model.weight == 0.0) + + +def test_to(): + m = torch.nn.Linear(11, 3) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + assert ema.shadow_params[0].dtype == torch.get_default_dtype() + ema.to(dtype=torch.float16) + assert ema.shadow_params[0].dtype == torch.float16 + ema.store() + # we store whatever we get + assert ema.collected_params[0].dtype == torch.get_default_dtype() + m = m.to(torch.float16) + ema.store(m.parameters()) + assert ema.collected_params[0].dtype == torch.float16 + ema.to(dtype=torch.float64) + assert ema.collected_params[0].dtype == torch.float64 + assert ema.shadow_params[0].dtype == torch.float64 diff --git a/torch_ema/ema.py b/torch_ema/ema.py index b3487cf..2aa3004 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -165,6 +165,28 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(c_param.data) + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] + return + def state_dict(self) -> dict: r"""Returns the state of the ExponentialMovingAverage as a dict.""" # Following PyTorch conventions, references to tensors are returned: -- cgit v1.2.3