From 984e7d5fbc8eb916a24086a7053335516c14ff34 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 12:49:49 -0600 Subject: Cast state dict to current device/dtype --- torch_ema/ema.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) (limited to 'torch_ema/ema.py') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 6c0415f..f6d8f6e 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -178,15 +178,42 @@ class ExponentialMovingAverage: self.num_updates = state_dict["num_updates"] assert self.num_updates is None or isinstance(self.num_updates, int), \ "Invalid num_updates" + + # Consistent with torch.optim.Optimizer, cast things to current + # device and dtype + if len(self.shadow_params) > 0: + device = self.shadow_params[0].device + dtype = self.shadow_params[0].dtype + else: + device = None + dtype = None + self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), \ "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" + # Cast shadow params: + if device is not None: + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + self.collected_params = state_dict["collected_params"] assert isinstance(self.collected_params, list), \ "collected_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.collected_params ), "collected_params must all be Tensors" + # Cast collected params: + if device is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] -- cgit v1.2.3