aboutsummaryrefslogtreecommitdiff
path: root/tests/test_ema.py
diff options
context:
space:
mode:
authorAlby M <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-04-21 13:42:20 -0600
committerAlby M <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-04-21 13:42:20 -0600
commit81120309acff37307b2226fbda12277ca1662f93 (patch)
treeb2c9b1693bda9275e141b38f06052eafb91e0e0b /tests/test_ema.py
parent81a99ed1ec6f576d6b8004c7000ca0bc023e7483 (diff)
Add .to()
Diffstat (limited to 'tests/test_ema.py')
-rw-r--r--tests/test_ema.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index aa43b14..edcea4c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -134,3 +134,20 @@ def test_explicit_params():
ema.update(model2.parameters())
ema.copy_to()
assert not torch.all(model.weight == 0.0)
+
+
+def test_to():
+ m = torch.nn.Linear(11, 3)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ assert ema.shadow_params[0].dtype == torch.get_default_dtype()
+ ema.to(dtype=torch.float16)
+ assert ema.shadow_params[0].dtype == torch.float16
+ ema.store()
+ # we store whatever we get
+ assert ema.collected_params[0].dtype == torch.get_default_dtype()
+ m = m.to(torch.float16)
+ ema.store(m.parameters())
+ assert ema.collected_params[0].dtype == torch.float16
+ ema.to(dtype=torch.float64)
+ assert ema.collected_params[0].dtype == torch.float64
+ assert ema.shadow_params[0].dtype == torch.float64