Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- | torch_ema/ema.py | 27 |
1 file changed, 27 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..f6d8f6e 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -178,15 +178,42 @@ class ExponentialMovingAverage:
         self.num_updates = state_dict["num_updates"]
         assert self.num_updates is None or isinstance(self.num_updates, int), \
             "Invalid num_updates"
+
+        # Consistent with torch.optim.Optimizer, cast things to current
+        # device and dtype
+        if len(self.shadow_params) > 0:
+            device = self.shadow_params[0].device
+            dtype = self.shadow_params[0].dtype
+        else:
+            device = None
+            dtype = None
+
         self.shadow_params = state_dict["shadow_params"]
         assert isinstance(self.shadow_params, list), \
             "shadow_params must be a list"
         assert all(
             isinstance(p, torch.Tensor) for p in self.shadow_params
         ), "shadow_params must all be Tensors"
+        # Cast shadow params:
+        if device is not None:
+            self.shadow_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.shadow_params
+            ]
+
         self.collected_params = state_dict["collected_params"]
         assert isinstance(self.collected_params, list), \
             "collected_params must be a list"
         assert all(
             isinstance(p, torch.Tensor) for p in self.collected_params
         ), "collected_params must all be Tensors"
+        # Cast collected params:
+        if device is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
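
For context, a minimal usage sketch of what the new casting does. This is not part of the commit; it assumes torch_ema's ExponentialMovingAverage(parameters, decay) constructor and the state_dict() counterpart of the load_state_dict() method patched above:

import copy

import torch
from torch_ema import ExponentialMovingAverage

# "Saved" run: an EMA tracking double-precision parameters.
model_a = torch.nn.Linear(4, 2).double()
ema_a = ExponentialMovingAverage(model_a.parameters(), decay=0.995)
state = copy.deepcopy(ema_a.state_dict())

# "Current" run: a fresh EMA whose shadow params are float32.
model_b = torch.nn.Linear(4, 2)
ema_b = ExponentialMovingAverage(model_b.parameters(), decay=0.995)
ema_b.load_state_dict(state)

# With this patch, the loaded shadow params are cast to the device and
# dtype of the current instance (CPU/float32 here) instead of keeping the
# float64 tensors stored in the state dict, mirroring the behavior of
# torch.optim.Optimizer.load_state_dict.
assert all(p.dtype == torch.float32 for p in ema_b.shadow_params)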