import pytest
import copy
import torch
from torch_ema import ExponentialMovingAverage
@pytest.mark.parametrize("decay", [0.995])
@pytest.mark.parametrize("use_num_updates", [True, False])
@pytest.mark.parametrize("explicit_params", [True, False])
def test_state_dict(decay, use_num_updates, explicit_params):
model = torch.nn.Linear(10, 2, bias=False)
with torch.no_grad():
model.weight.fill_(0.0)
ema = ExponentialMovingAverage(
model.parameters(),
decay=decay,
use_num_updates=False
)
state_dict = copy.deepcopy(ema.state_dict())
model2 = torch.nn.Linear(10, 2, bias=False)
ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
ema2.load_state_dict(state_dict)
assert ema2.decay == decay
assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
with torch.no_grad():
model2.weight.fill_(1.0)
if explicit_params:
ema2.update(model2.parameters())
else:
ema2.update()
assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
ema.load_state_dict(ema2.state_dict())
if explicit_params:
ema.copy_to(model.parameters())
else:
ema.copy_to()
assert torch.allclose(
model.weight,
torch.full(size=(1,), fill_value=(1.0 - decay))
), "average was wrong"


def test_state_dict_types():
    m1 = torch.nn.Linear(10, 2, bias=False)
    m2 = torch.nn.Linear(10, 2, bias=False)
    m2.to(torch.float16)
    ema1 = ExponentialMovingAverage(m1.parameters(), decay=0.9)
    ema2 = ExponentialMovingAverage(m2.parameters(), decay=0.9)
    ema1.update()
    ema2.update()
    # Loading a float32 state dict into an EMA that tracks float16 parameters
    # must not change the dtype of those parameters when copying back.
    ema2.load_state_dict(ema1.state_dict())
    ema1.copy_to()
    ema2.copy_to()
    assert m1.weight.dtype == torch.get_default_dtype()
    assert m2.weight.dtype == torch.float16
    assert torch.allclose(m1.weight.to(torch.float16), m2.weight)


def test_bad_state_dict1():
    m = torch.nn.Linear(10, 2, bias=False)
    ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
    sd = ema.state_dict()
    # Corrupt the state dict with a shadow parameter of the wrong shape.
    sd["shadow_params"][0] = torch.zeros(3, 7)
    # Loading doesn't raise, since load_state_dict can't know the parameter shapes.
    ema.load_state_dict(sd)
    with pytest.raises(RuntimeError):
        ema.copy_to()
    # Make sure the failed copy didn't change the model weights.
    assert torch.any(m.weight.abs() > 0)


def test_bad_state_dict2():
    m = torch.nn.Linear(10, 2, bias=False)
    ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
    sd = ema.state_dict()
    sd["shadow_params"] = sd["shadow_params"][:-1]
    with pytest.raises(ValueError):
        ema.load_state_dict(sd)
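

# Illustrative sketch, not part of the upstream test suite: the same state-dict
# round trip the tests above perform in memory, done through torch.save /
# torch.load as it would be in a checkpointing workflow. Only the public API
# already exercised above is assumed; the helper name, file name, and use of
# pytest's tmp_path fixture are made up for this example, and the leading
# underscore keeps pytest from collecting it as a test.
def _example_checkpoint_roundtrip(tmp_path):
    model = torch.nn.Linear(10, 2, bias=False)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
    ema.update()

    # Save the EMA state alongside the usual model checkpoint.
    path = tmp_path / "ema_checkpoint.pt"
    torch.save(ema.state_dict(), path)

    # A freshly constructed EMA picks up the saved averages.
    restored = ExponentialMovingAverage(model.parameters(), decay=0.995)
    restored.load_state_dict(torch.load(path))
    assert torch.allclose(restored.shadow_params[0], ema.shadow_params[0])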