aboutsummaryrefslogtreecommitdiff
path: root/tests/test_ema.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_ema.py')
-rw-r--r--tests/test_ema.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index aa43b14..edcea4c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -134,3 +134,20 @@ def test_explicit_params():
ema.update(model2.parameters())
ema.copy_to()
assert not torch.all(model.weight == 0.0)
+
+
+def test_to():
+ m = torch.nn.Linear(11, 3)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ assert ema.shadow_params[0].dtype == torch.get_default_dtype()
+ ema.to(dtype=torch.float16)
+ assert ema.shadow_params[0].dtype == torch.float16
+ ema.store()
+ # we store whatever we get
+ assert ema.collected_params[0].dtype == torch.get_default_dtype()
+ m = m.to(torch.float16)
+ ema.store(m.parameters())
+ assert ema.collected_params[0].dtype == torch.float16
+ ema.to(dtype=torch.float64)
+ assert ema.collected_params[0].dtype == torch.float64
+ assert ema.shadow_params[0].dtype == torch.float64