aboutsummaryrefslogtreecommitdiff
path: root/tests/test_ema.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_ema.py')
-rw-r--r--tests/test_ema.py55
1 files changed, 15 insertions, 40 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index 67a14dc..edcea4c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -1,7 +1,5 @@
import pytest
-import copy
-
import torch
from torch_ema import ExponentialMovingAverage
@@ -138,41 +136,18 @@ def test_explicit_params():
assert not torch.all(model.weight == 0.0)
-@pytest.mark.parametrize("decay", [0.995])
-@pytest.mark.parametrize("use_num_updates", [True, False])
-@pytest.mark.parametrize("explicit_params", [True, False])
-def test_state_dict(decay, use_num_updates, explicit_params):
- model = torch.nn.Linear(10, 2, bias=False)
- with torch.no_grad():
- model.weight.fill_(0.0)
- ema = ExponentialMovingAverage(
- model.parameters(),
- decay=decay,
- use_num_updates=False
- )
- state_dict = copy.deepcopy(ema.state_dict())
-
- model2 = torch.nn.Linear(10, 2, bias=False)
- ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
- ema2.load_state_dict(state_dict)
- assert ema2.decay == decay
- assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
-
- with torch.no_grad():
- model2.weight.fill_(1.0)
- if explicit_params:
- ema2.update(model2.parameters())
- else:
- ema2.update()
- assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
-
- ema.load_state_dict(ema2.state_dict())
-
- if explicit_params:
- ema.copy_to(model.parameters())
- else:
- ema.copy_to()
- assert torch.allclose(
- model.weight,
- torch.full(size=(1,), fill_value=(1.0 - decay))
- ), "average was wrong"
+def test_to():
+ m = torch.nn.Linear(11, 3)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ assert ema.shadow_params[0].dtype == torch.get_default_dtype()
+ ema.to(dtype=torch.float16)
+ assert ema.shadow_params[0].dtype == torch.float16
+ ema.store()
+ # we store whatever we get
+ assert ema.collected_params[0].dtype == torch.get_default_dtype()
+ m = m.to(torch.float16)
+ ema.store(m.parameters())
+ assert ema.collected_params[0].dtype == torch.float16
+ ema.to(dtype=torch.float64)
+ assert ema.collected_params[0].dtype == torch.float64
+ assert ema.shadow_params[0].dtype == torch.float64