diff options
author | Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-04-21 13:42:20 -0600 |
---|---|---|
committer | Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-04-21 13:42:20 -0600 |
commit | 81120309acff37307b2226fbda12277ca1662f93 (patch) | |
tree | b2c9b1693bda9275e141b38f06052eafb91e0e0b /torch_ema/ema.py | |
parent | 81a99ed1ec6f576d6b8004c7000ca0bc023e7483 (diff) |
Add .to()
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- | torch_ema/ema.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index b3487cf..2aa3004 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -165,6 +165,28 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(c_param.data) + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] + return + def state_dict(self) -> dict: r"""Returns the state of the ExponentialMovingAverage as a dict.""" # Following PyTorch conventions, references to tensors are returned: |