diff options
Diffstat (limited to 'tests/test_ema.py')
-rw-r--r-- | tests/test_ema.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index aa43b14..edcea4c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -134,3 +134,20 @@ def test_explicit_params(): ema.update(model2.parameters()) ema.copy_to() assert not torch.all(model.weight == 0.0) + + +def test_to(): + m = torch.nn.Linear(11, 3) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + assert ema.shadow_params[0].dtype == torch.get_default_dtype() + ema.to(dtype=torch.float16) + assert ema.shadow_params[0].dtype == torch.float16 + ema.store() + # we store whatever we get + assert ema.collected_params[0].dtype == torch.get_default_dtype() + m = m.to(torch.float16) + ema.store(m.parameters()) + assert ema.collected_params[0].dtype == torch.float16 + ema.to(dtype=torch.float64) + assert ema.collected_params[0].dtype == torch.float64 + assert ema.shadow_params[0].dtype == torch.float64 |