aboutsummaryrefslogtreecommitdiff
path: root/tests/test_ema.py
blob: 6d7e43e58a72be6d5d2b11a4e6b50683255bb7f9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import pytest

import torch

from torch_ema import ExponentialMovingAverage


@pytest.mark.parametrize("decay", [0.995, 0.9])
@pytest.mark.parametrize("use_num_updates", [True, False])
def test_val_error(decay, use_num_updates):
    """Confirm that EMA validation error is lower than raw validation error.

    Trains a small linear classifier on seeded random data for 20 full-batch
    Adam steps, calling ``ema.update`` after every optimizer step, then checks
    that:

    * installing the EMA weights (``store`` followed by ``copy_to``) gives a
      lower validation loss than the raw trained weights, and
    * ``restore`` puts the trained weights back exactly, so both the
      parameters and the validation loss match what they were before.
    """
    torch.manual_seed(0)
    x_train = torch.rand((100, 10))
    y_train = torch.rand(100).round().long()
    x_val = torch.rand((100, 10))
    y_val = torch.rand(100).round().long()
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        use_num_updates=use_num_updates
    )

    # Train for a few epochs
    model.train()
    for _ in range(20):
        logits = model(x_train)
        loss = torch.nn.functional.cross_entropy(logits, y_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model.parameters())

    # Everything below is evaluation only; the model stays in eval mode.
    model.eval()

    # Validation: original. No autograd graph is needed for a forward pass
    # whose result is only compared.
    with torch.no_grad():
        loss_orig = torch.nn.functional.cross_entropy(model(x_val), y_val)
    print(f"Original loss: {loss_orig}")
    # Snapshot the trained weights so restore() can be checked exactly,
    # not just through the loss.
    trained_params = [p.detach().clone() for p in model.parameters()]

    # Validation: with EMA
    # First save original parameters before replacing with EMA version
    ema.store(model.parameters())
    # Copy EMA parameters to model
    ema.copy_to(model.parameters())
    with torch.no_grad():
        loss_ema = torch.nn.functional.cross_entropy(model(x_val), y_val)
    print(f"EMA loss: {loss_ema}")
    assert loss_ema < loss_orig, "EMA loss wasn't lower"

    # Test restore: every parameter must be bit-identical to the trained
    # weights, and therefore the validation loss must match as well.
    ema.restore(model.parameters())
    for restored, trained in zip(model.parameters(), trained_params):
        assert torch.equal(restored, trained), \
            "Restored model wasn't the same as stored model"
    with torch.no_grad():
        loss_orig2 = torch.nn.functional.cross_entropy(model(x_val), y_val)
    assert torch.allclose(loss_orig, loss_orig2), \
        "Restored model wasn't the same as stored model"


@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
@pytest.mark.parametrize("use_num_updates", [True, False])
def test_store_restore(decay, use_num_updates):
    """Check that ``restore`` undoes any change made after ``store``.

    Every parameter of the model, the weight *and* the bias, is overwritten
    with fresh uniform values between ``store`` and ``restore``. All of them
    must come back with exactly their stored values.
    """
    model = torch.nn.Linear(10, 2)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        use_num_updates=use_num_updates
    )
    # detach() before clone() so the copy is not recorded by autograd.
    orig_params = [p.detach().clone() for p in model.parameters()]
    ema.store(model.parameters())
    with torch.no_grad():
        for param in model.parameters():
            param.uniform_(0.0, 1.0)
    ema.restore(model.parameters())
    for restored, orig in zip(model.parameters(), orig_params):
        assert torch.equal(restored, orig), \
            "restore did not bring back the stored parameters"


@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
def test_update(decay):
    """A single ``update`` blends the shadow weights toward the live weights.

    The EMA is created while the weight is all zeros. The weight is then set
    to all ones and ``update`` is called once with ``use_num_updates=False``,
    so the averaged value should be ``decay * 0 + (1 - decay) * 1``,
    i.e. ``1 - decay``. ``update`` itself must leave the live weight alone.
    """
    model = torch.nn.Linear(10, 2, bias=False)
    weight = model.weight

    with torch.no_grad():
        weight.fill_(0.0)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        use_num_updates=False
    )

    with torch.no_grad():
        weight.fill_(1.0)
    ema.update(model.parameters())
    assert (weight == 1.0).all(), "ema.update changed model weights"

    ema.copy_to(model.parameters())
    expected = torch.full_like(weight, 1.0 - decay)
    assert torch.allclose(weight, expected), "average was wrong"