aboutsummaryrefslogtreecommitdiff
path: root/torch_ema
diff options
context:
space:
mode:
authorAlby M <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-04-21 12:49:49 -0600
committerAlby M <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-04-21 12:49:49 -0600
commit984e7d5fbc8eb916a24086a7053335516c14ff34 (patch)
tree4aa29075097638a6b19e4e561ccfe62d965ea7a3 /torch_ema
parenta8223abaad1da1293f350d80b636a8d67b2d58a5 (diff)
Cast state dict to current device/dtype
Diffstat (limited to 'torch_ema')
-rw-r--r--torch_ema/ema.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..f6d8f6e 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -178,15 +178,42 @@ class ExponentialMovingAverage:
self.num_updates = state_dict["num_updates"]
assert self.num_updates is None or isinstance(self.num_updates, int), \
"Invalid num_updates"
+
+ # Consistant with torch.optim.Optimizer, cast things to current
+ # device and dtype
+ if len(self.shadow_params) > 0:
+ device = self.shadow_params[0].device
+ dtype = self.shadow_params[0].dtype
+ else:
+ device = None
+ dtype = None
+
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), \
"shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
+ # Cast shadow params:
+ if device is not None:
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.shadow_params
+ ]
+
self.collected_params = state_dict["collected_params"]
assert isinstance(self.collected_params, list), \
"collected_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.collected_params
), "collected_params must all be Tensors"
+ # Cast collected params:
+ if device is not None:
+ self.collected_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.collected_params
+ ]