aboutsummaryrefslogtreecommitdiff
import pytest

import copy

import torch

from torch_ema import ExponentialMovingAverage


@pytest.mark.parametrize("decay", [0.995])
@pytest.mark.parametrize("use_num_updates", [True, False])
@pytest.mark.parametrize("explicit_params", [True, False])
def test_state_dict(decay, use_num_updates, explicit_params):
    """Round-trip EMA state through state_dict() / load_state_dict().

    An EMA over an all-zero model is serialized and loaded into a second EMA
    that was constructed with a different decay. The second EMA is updated
    once against an all-ones model, and its state is loaded back into the
    first EMA. The average copied out of the first EMA must equal a single
    EMA step from 0 toward 1.

    Args:
        decay: EMA decay given to the source EMA.
        use_num_updates: whether the source EMA tracks an update counter;
            that setting travels through the state dict to the second EMA.
        explicit_params: whether update() / copy_to() receive the parameters
            explicitly or fall back to those given at construction.
    """
    model = torch.nn.Linear(10, 2, bias=False)
    with torch.no_grad():
        model.weight.fill_(0.0)
    ema = ExponentialMovingAverage(
        model.parameters(),
        decay=decay,
        # Was hard-coded to False, which made the `use_num_updates`
        # parametrization a no-op.
        use_num_updates=use_num_updates,
    )
    # Deep-copy so nothing done to `ema` later can alter the saved state.
    state_dict = copy.deepcopy(ema.state_dict())

    model2 = torch.nn.Linear(10, 2, bias=False)
    # decay=0.0 differs from `decay`, so the assertion below proves the
    # decay really came from the loaded state dict.
    ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
    ema2.load_state_dict(state_dict)
    assert ema2.decay == decay
    assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])

    with torch.no_grad():
        model2.weight.fill_(1.0)
    if explicit_params:
        ema2.update(model2.parameters())
    else:
        ema2.update()
    assert torch.all(model2.weight == 1.0), "ema.update changed model weights"

    ema.load_state_dict(ema2.state_dict())

    if explicit_params:
        ema.copy_to(model.parameters())
    else:
        ema.copy_to()

    # One step from shadow 0 toward param 1 leaves the shadow at
    # 1 - effective_decay. With update counting enabled, torch_ema caps the
    # decay at (1 + n) / (10 + n) after n updates (the same warm-up rule as
    # TensorFlow's ExponentialMovingAverage). Exactly one update ran, so n == 1.
    if use_num_updates:
        effective_decay = min(decay, (1 + 1) / (10 + 1))
    else:
        effective_decay = decay
    assert torch.allclose(
        model.weight,
        torch.full_like(model.weight, 1.0 - effective_decay),
    ), "average was wrong"


def test_state_dict_types():
    """Loading a default-dtype EMA state into a float16 EMA keeps model dtypes.

    Both EMAs are updated once. The default-dtype state is then loaded into
    the float16 EMA, and each EMA copies its average back into its own model.
    Each model must keep its dtype, and the two weights must agree once both
    are viewed as float16.
    """
    full_model = torch.nn.Linear(10, 2, bias=False)
    half_model = torch.nn.Linear(10, 2, bias=False)
    half_model.to(torch.float16)
    ema_full, ema_half = (
        ExponentialMovingAverage(net.parameters(), decay=0.9)
        for net in (full_model, half_model)
    )
    for avg in (ema_full, ema_half):
        avg.update()
    ema_half.load_state_dict(ema_full.state_dict())
    for avg in (ema_full, ema_half):
        avg.copy_to()
    assert full_model.weight.dtype == torch.get_default_dtype()
    assert half_model.weight.dtype == torch.float16
    reference = full_model.weight.to(torch.float16)
    assert torch.allclose(reference, half_model.weight)


def test_bad_state_dict1():
    """A shadow param of the wrong shape loads, but copy_to() rejects it.

    The failed copy must leave the model's weights exactly as they were.
    """
    m = torch.nn.Linear(10, 2, bias=False)
    ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
    # Deep-copy before corrupting, so the bad tensor can only reach `ema`
    # through load_state_dict(). NOTE(review): state_dict() may return live
    # references to the EMA's own list.
    sd = copy.deepcopy(ema.state_dict())
    sd["shadow_params"][0] = torch.zeros(3, 7)
    # Loading does not raise, since the EMA cannot check shapes at that point.
    ema.load_state_dict(sd)
    weight_before = m.weight.detach().clone()
    with pytest.raises(RuntimeError):
        ema.copy_to()
    # Make sure the failed copy left the weights exactly unchanged; the old
    # "not all zero" check could not detect a partial overwrite.
    assert torch.equal(m.weight, weight_before)


def test_bad_state_dict2():
    """load_state_dict() raises ValueError when shadow params are missing."""
    model = torch.nn.Linear(10, 2, bias=False)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
    truncated = ema.state_dict()
    shadow = truncated["shadow_params"]
    # Drop the final entry so the count no longer matches the model's params.
    truncated["shadow_params"] = shadow[:len(shadow) - 1]
    with pytest.raises(ValueError):
        ema.load_state_dict(truncated)