diff options
author | Samuel Fadel <samuelfadel@gmail.com> | 2021-07-02 13:22:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-02 13:22:55 +0200 |
commit | 72ac3d3333c3dc1d95eacdedbdb5a0132958973a (patch) | |
tree | 402dd92d2b1cf67aa7909c4eb87339b1a0acfc4c /tests/test_state_dict.py | |
parent | a8223abaad1da1293f350d80b636a8d67b2d58a5 (diff) | |
parent | e668ae1e0a757cf8217e926be9ae228676fbe17b (diff) |
Merge pull request #7 from Linux-cpp-lisp/device_and_dtype
Device and dtype improvements
Diffstat (limited to 'tests/test_state_dict.py')
-rw-r--r-- | tests/test_state_dict.py | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py new file mode 100644 index 0000000..814f446 --- /dev/null +++ b/tests/test_state_dict.py @@ -0,0 +1,85 @@ +import pytest + +import copy + +import torch + +from torch_ema import ExponentialMovingAverage + + +@pytest.mark.parametrize("decay", [0.995]) +@pytest.mark.parametrize("use_num_updates", [True, False]) +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_state_dict(decay, use_num_updates, explicit_params): + model = torch.nn.Linear(10, 2, bias=False) + with torch.no_grad(): + model.weight.fill_(0.0) + ema = ExponentialMovingAverage( + model.parameters(), + decay=decay, + use_num_updates=False + ) + state_dict = copy.deepcopy(ema.state_dict()) + + model2 = torch.nn.Linear(10, 2, bias=False) + ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0) + ema2.load_state_dict(state_dict) + assert ema2.decay == decay + assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0]) + + with torch.no_grad(): + model2.weight.fill_(1.0) + if explicit_params: + ema2.update(model2.parameters()) + else: + ema2.update() + assert torch.all(model2.weight == 1.0), "ema.update changed model weights" + + ema.load_state_dict(ema2.state_dict()) + + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() + assert torch.allclose( + model.weight, + torch.full(size=(1,), fill_value=(1.0 - decay)) + ), "average was wrong" + + +def test_state_dict_types(): + m1 = torch.nn.Linear(10, 2, bias=False) + m2 = torch.nn.Linear(10, 2, bias=False) + m2.to(torch.float16) + ema1 = ExponentialMovingAverage(m1.parameters(), decay=0.9) + ema2 = ExponentialMovingAverage(m2.parameters(), decay=0.9) + ema1.update() + ema2.update() + ema2.load_state_dict(ema1.state_dict()) + ema1.copy_to() + ema2.copy_to() + assert m1.weight.dtype == torch.get_default_dtype() + assert m2.weight.dtype == torch.float16 + assert torch.allclose(m1.weight.to(torch.float16), m2.weight) + + +def test_bad_state_dict1(): + m = torch.nn.Linear(10, 2, bias=False) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + sd = ema.state_dict() + sd["shadow_params"][0] = torch.zeros(3, 7) + # it doesn't raise at loading, since it can't know shapes. + ema.load_state_dict(sd) + with pytest.raises(RuntimeError): + ema.copy_to() + # make sure it didn't change + assert torch.any(m.weight.abs() > 0) + + +def test_bad_state_dict2(): + m = torch.nn.Linear(10, 2, bias=False) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + sd = ema.state_dict() + sd["shadow_params"] = sd["shadow_params"][:-1] + with pytest.raises(ValueError): + ema.load_state_dict(sd) |