aboutsummaryrefslogtreecommitdiff
path: root/tests/test_state_dict.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state_dict.py')
-rw-r--r--tests/test_state_dict.py85
1 files changed, 85 insertions, 0 deletions
diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py
new file mode 100644
index 0000000..814f446
--- /dev/null
+++ b/tests/test_state_dict.py
@@ -0,0 +1,85 @@
+import pytest
+
+import copy
+
+import torch
+
+from torch_ema import ExponentialMovingAverage
+
+
+@pytest.mark.parametrize("decay", [0.995])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_state_dict(decay, use_num_updates, explicit_params):
+ model = torch.nn.Linear(10, 2, bias=False)
+ with torch.no_grad():
+ model.weight.fill_(0.0)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=decay,
+ use_num_updates=False
+ )
+ state_dict = copy.deepcopy(ema.state_dict())
+
+ model2 = torch.nn.Linear(10, 2, bias=False)
+ ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
+ ema2.load_state_dict(state_dict)
+ assert ema2.decay == decay
+ assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
+
+ with torch.no_grad():
+ model2.weight.fill_(1.0)
+ if explicit_params:
+ ema2.update(model2.parameters())
+ else:
+ ema2.update()
+ assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
+
+ ema.load_state_dict(ema2.state_dict())
+
+ if explicit_params:
+ ema.copy_to(model.parameters())
+ else:
+ ema.copy_to()
+ assert torch.allclose(
+ model.weight,
+ torch.full(size=(1,), fill_value=(1.0 - decay))
+ ), "average was wrong"
+
+
+def test_state_dict_types():
+ m1 = torch.nn.Linear(10, 2, bias=False)
+ m2 = torch.nn.Linear(10, 2, bias=False)
+ m2.to(torch.float16)
+ ema1 = ExponentialMovingAverage(m1.parameters(), decay=0.9)
+ ema2 = ExponentialMovingAverage(m2.parameters(), decay=0.9)
+ ema1.update()
+ ema2.update()
+ ema2.load_state_dict(ema1.state_dict())
+ ema1.copy_to()
+ ema2.copy_to()
+ assert m1.weight.dtype == torch.get_default_dtype()
+ assert m2.weight.dtype == torch.float16
+ assert torch.allclose(m1.weight.to(torch.float16), m2.weight)
+
+
+def test_bad_state_dict1():
+ m = torch.nn.Linear(10, 2, bias=False)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ sd = ema.state_dict()
+ sd["shadow_params"][0] = torch.zeros(3, 7)
+ # it doesn't raise at loading, since it can't know shapes.
+ ema.load_state_dict(sd)
+ with pytest.raises(RuntimeError):
+ ema.copy_to()
+ # make sure it didn't change
+ assert torch.any(m.weight.abs() > 0)
+
+
+def test_bad_state_dict2():
+ m = torch.nn.Linear(10, 2, bias=False)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ sd = ema.state_dict()
+ sd["shadow_params"] = sd["shadow_params"][:-1]
+ with pytest.raises(ValueError):
+ ema.load_state_dict(sd)