author     Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>   2021-04-21 12:49:49 -0600
committer  Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>   2021-04-21 12:49:49 -0600
commit     984e7d5fbc8eb916a24086a7053335516c14ff34 (patch)
tree       4aa29075097638a6b19e4e561ccfe62d965ea7a3 /torch_ema
parent     a8223abaad1da1293f350d80b636a8d67b2d58a5 (diff)
Cast state dict to current device/dtype
Diffstat (limited to 'torch_ema')
-rw-r--r--  torch_ema/ema.py  27
1 file changed, 27 insertions(+), 0 deletions(-)
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..f6d8f6e 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -178,15 +178,42 @@ class ExponentialMovingAverage:
         self.num_updates = state_dict["num_updates"]
         assert self.num_updates is None or isinstance(self.num_updates, int), \
             "Invalid num_updates"
+
+        # Consistent with torch.optim.Optimizer, cast things to current
+        # device and dtype
+        if len(self.shadow_params) > 0:
+            device = self.shadow_params[0].device
+            dtype = self.shadow_params[0].dtype
+        else:
+            device = None
+            dtype = None
+
         self.shadow_params = state_dict["shadow_params"]
         assert isinstance(self.shadow_params, list), \
             "shadow_params must be a list"
         assert all(
             isinstance(p, torch.Tensor) for p in self.shadow_params
         ), "shadow_params must all be Tensors"
+        # Cast shadow params:
+        if device is not None:
+            self.shadow_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.shadow_params
+            ]
+
         self.collected_params = state_dict["collected_params"]
         assert isinstance(self.collected_params, list), \
             "collected_params must be a list"
         assert all(
             isinstance(p, torch.Tensor) for p in self.collected_params
         ), "collected_params must all be Tensors"
+        # Cast collected params:
+        if device is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
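
For context, a minimal sketch of the behaviour this change enables (not part of the commit; the Linear model and decay value are made up for illustration, and the ExponentialMovingAverage constructor/state_dict usage is assumed to match the library at this revision): a state dict whose shadow params were saved as float32 is cast to the device and dtype of the receiving EMA's current shadow params when loaded, mirroring torch.optim.Optimizer.

import torch
from torch_ema import ExponentialMovingAverage

# EMA tracking a float32 CPU model; its state dict stores float32 shadow params.
model = torch.nn.Linear(4, 2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
state = ema.state_dict()

# A second EMA whose shadow params are float64, standing in for the
# "current device/dtype" that the loaded tensors should be cast to.
model_fp64 = torch.nn.Linear(4, 2).double()
ema_fp64 = ExponentialMovingAverage(model_fp64.parameters(), decay=0.995)

# With this patch, load_state_dict casts the loaded floating-point tensors
# to the device/dtype of the existing shadow params (float64 here).
ema_fp64.load_state_dict(state)
assert all(p.dtype == torch.float64 for p in ema_fp64.shadow_params)

The same casting applies across devices (e.g. a CPU checkpoint loaded into an EMA whose shadow params live on CUDA); non-floating-point tensors are moved to the target device but keep their original dtype, as in the diff above.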