commit    81120309acff37307b2226fbda12277ca1662f93 (patch)
tree      b2c9b1693bda9275e141b38f06052eafb91e0e0b /torch_ema/ema.py
parent    81a99ed1ec6f576d6b8004c7000ca0bc023e7483 (diff)
author    Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>  2021-04-21 13:42:20 -0600
committer Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>  2021-04-21 13:42:20 -0600
Add .to()
Diffstat (limited to 'torch_ema/ema.py')
 torch_ema/ema.py | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+), 0 deletions(-)
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index b3487cf..2aa3004 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -165,6 +165,29 @@ class ExponentialMovingAverage:
             if param.requires_grad:
                 param.data.copy_(c_param.data)
 
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+            dtype: like `dtype` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype)
+            if p.is_floating_point()
+            else p.to(device=device)
+            for p in self.shadow_params
+        ]
+        if self.collected_params is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
+        return
+
     def state_dict(self) -> dict:
         r"""Returns the state of the ExponentialMovingAverage as a dict."""
         # Following PyTorch conventions, references to tensors are returned:
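
For reference, a minimal usage sketch of the method added by this commit. The model, decay value, and device check are hypothetical and for illustration only; the `ExponentialMovingAverage(parameters, decay=...)` constructor follows this repository's existing API.

```python
import torch
from torch_ema import ExponentialMovingAverage

# Hypothetical model and decay value, for illustration only.
model = torch.nn.Linear(10, 2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

# Keep the EMA's shadow parameters on the same device as the model.
# When dtype is given, only floating-point buffers are cast.
if torch.cuda.is_available():
    model.to("cuda")
    ema.to(device="cuda")
```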