diff options
author | Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-04-21 13:42:20 -0600 |
---|---|---|
committer | Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-04-21 13:42:20 -0600 |
commit | 81120309acff37307b2226fbda12277ca1662f93 (patch) | |
tree | b2c9b1693bda9275e141b38f06052eafb91e0e0b /tests | |
parent | 81a99ed1ec6f576d6b8004c7000ca0bc023e7483 (diff) |
Add .to()
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_ema.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index aa43b14..edcea4c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -134,3 +134,20 @@ def test_explicit_params(): ema.update(model2.parameters()) ema.copy_to() assert not torch.all(model.weight == 0.0) + + +def test_to(): + m = torch.nn.Linear(11, 3) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + assert ema.shadow_params[0].dtype == torch.get_default_dtype() + ema.to(dtype=torch.float16) + assert ema.shadow_params[0].dtype == torch.float16 + ema.store() + # we store whatever we get + assert ema.collected_params[0].dtype == torch.get_default_dtype() + m = m.to(torch.float16) + ema.store(m.parameters()) + assert ema.collected_params[0].dtype == torch.float16 + ema.to(dtype=torch.float64) + assert ema.collected_params[0].dtype == torch.float64 + assert ema.shadow_params[0].dtype == torch.float64 |