diff options
author | Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-04-21 13:35:20 -0600 |
---|---|---|
committer | Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-04-21 13:35:20 -0600 |
commit | 81a99ed1ec6f576d6b8004c7000ca0bc023e7483 (patch) | |
tree | 16ff28a504c2285eb3baea3afc72f39c2efffe86 /tests/test_ema.py | |
parent | bf6d797c31b35b846c072618c2c8631feeb6db38 (diff) |
More state_dict tests
Diffstat (limited to 'tests/test_ema.py')
-rw-r--r-- | tests/test_ema.py | 42 |
1 files changed, 0 insertions, 42 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index 67a14dc..aa43b14 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -1,7 +1,5 @@ import pytest -import copy - import torch from torch_ema import ExponentialMovingAverage @@ -136,43 +134,3 @@ def test_explicit_params(): ema.update(model2.parameters()) ema.copy_to() assert not torch.all(model.weight == 0.0) - - -@pytest.mark.parametrize("decay", [0.995]) -@pytest.mark.parametrize("use_num_updates", [True, False]) -@pytest.mark.parametrize("explicit_params", [True, False]) -def test_state_dict(decay, use_num_updates, explicit_params): - model = torch.nn.Linear(10, 2, bias=False) - with torch.no_grad(): - model.weight.fill_(0.0) - ema = ExponentialMovingAverage( - model.parameters(), - decay=decay, - use_num_updates=False - ) - state_dict = copy.deepcopy(ema.state_dict()) - - model2 = torch.nn.Linear(10, 2, bias=False) - ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0) - ema2.load_state_dict(state_dict) - assert ema2.decay == decay - assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0]) - - with torch.no_grad(): - model2.weight.fill_(1.0) - if explicit_params: - ema2.update(model2.parameters()) - else: - ema2.update() - assert torch.all(model2.weight == 1.0), "ema.update changed model weights" - - ema.load_state_dict(ema2.state_dict()) - - if explicit_params: - ema.copy_to(model.parameters()) - else: - ema.copy_to() - assert torch.allclose( - model.weight, - torch.full(size=(1,), fill_value=(1.0 - decay)) - ), "average was wrong" |