aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py27
1 files changed, 27 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..f6d8f6e 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -178,15 +178,42 @@ class ExponentialMovingAverage:
self.num_updates = state_dict["num_updates"]
assert self.num_updates is None or isinstance(self.num_updates, int), \
"Invalid num_updates"
+
+ # Consistant with torch.optim.Optimizer, cast things to current
+ # device and dtype
+ if len(self.shadow_params) > 0:
+ device = self.shadow_params[0].device
+ dtype = self.shadow_params[0].dtype
+ else:
+ device = None
+ dtype = None
+
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), \
"shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
+ # Cast shadow params:
+ if device is not None:
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.shadow_params
+ ]
+
self.collected_params = state_dict["collected_params"]
assert isinstance(self.collected_params, list), \
"collected_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.collected_params
), "collected_params must all be Tensors"
+ # Cast collected params:
+ if device is not None:
+ self.collected_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.collected_params
+ ]