From 81a99ed1ec6f576d6b8004c7000ca0bc023e7483 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 13:35:20 -0600 Subject: More state_dict tests --- tests/test_ema.py | 42 ------------------------ tests/test_state_dict.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+), 42 deletions(-) create mode 100644 tests/test_state_dict.py (limited to 'tests') diff --git a/tests/test_ema.py b/tests/test_ema.py index 67a14dc..aa43b14 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -1,7 +1,5 @@ import pytest -import copy - import torch from torch_ema import ExponentialMovingAverage @@ -136,43 +134,3 @@ def test_explicit_params(): ema.update(model2.parameters()) ema.copy_to() assert not torch.all(model.weight == 0.0) - - -@pytest.mark.parametrize("decay", [0.995]) -@pytest.mark.parametrize("use_num_updates", [True, False]) -@pytest.mark.parametrize("explicit_params", [True, False]) -def test_state_dict(decay, use_num_updates, explicit_params): - model = torch.nn.Linear(10, 2, bias=False) - with torch.no_grad(): - model.weight.fill_(0.0) - ema = ExponentialMovingAverage( - model.parameters(), - decay=decay, - use_num_updates=False - ) - state_dict = copy.deepcopy(ema.state_dict()) - - model2 = torch.nn.Linear(10, 2, bias=False) - ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0) - ema2.load_state_dict(state_dict) - assert ema2.decay == decay - assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0]) - - with torch.no_grad(): - model2.weight.fill_(1.0) - if explicit_params: - ema2.update(model2.parameters()) - else: - ema2.update() - assert torch.all(model2.weight == 1.0), "ema.update changed model weights" - - ema.load_state_dict(ema2.state_dict()) - - if explicit_params: - ema.copy_to(model.parameters()) - else: - ema.copy_to() - assert torch.allclose( - model.weight, - torch.full(size=(1,), fill_value=(1.0 - decay)) - ), "average was wrong" diff 
--git a/tests/test_state_dict.py b/tests/test_state_dict.py new file mode 100644 index 0000000..814f446 --- /dev/null +++ b/tests/test_state_dict.py @@ -0,0 +1,85 @@ +import pytest + +import copy + +import torch + +from torch_ema import ExponentialMovingAverage + + +@pytest.mark.parametrize("decay", [0.995]) +@pytest.mark.parametrize("use_num_updates", [True, False]) +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_state_dict(decay, use_num_updates, explicit_params): + model = torch.nn.Linear(10, 2, bias=False) + with torch.no_grad(): + model.weight.fill_(0.0) + ema = ExponentialMovingAverage( + model.parameters(), + decay=decay, + use_num_updates=False + ) + state_dict = copy.deepcopy(ema.state_dict()) + + model2 = torch.nn.Linear(10, 2, bias=False) + ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0) + ema2.load_state_dict(state_dict) + assert ema2.decay == decay + assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0]) + + with torch.no_grad(): + model2.weight.fill_(1.0) + if explicit_params: + ema2.update(model2.parameters()) + else: + ema2.update() + assert torch.all(model2.weight == 1.0), "ema.update changed model weights" + + ema.load_state_dict(ema2.state_dict()) + + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() + assert torch.allclose( + model.weight, + torch.full(size=(1,), fill_value=(1.0 - decay)) + ), "average was wrong" + + +def test_state_dict_types(): + m1 = torch.nn.Linear(10, 2, bias=False) + m2 = torch.nn.Linear(10, 2, bias=False) + m2.to(torch.float16) + ema1 = ExponentialMovingAverage(m1.parameters(), decay=0.9) + ema2 = ExponentialMovingAverage(m2.parameters(), decay=0.9) + ema1.update() + ema2.update() + ema2.load_state_dict(ema1.state_dict()) + ema1.copy_to() + ema2.copy_to() + assert m1.weight.dtype == torch.get_default_dtype() + assert m2.weight.dtype == torch.float16 + assert torch.allclose(m1.weight.to(torch.float16), m2.weight) + + +def test_bad_state_dict1(): 
+ m = torch.nn.Linear(10, 2, bias=False) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + sd = ema.state_dict() + sd["shadow_params"][0] = torch.zeros(3, 7) + # it doesn't raise at loading, since it can't know shapes. + ema.load_state_dict(sd) + with pytest.raises(RuntimeError): + ema.copy_to() + # make sure it didn't change + assert torch.any(m.weight.abs() > 0) + + +def test_bad_state_dict2(): + m = torch.nn.Linear(10, 2, bias=False) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + sd = ema.state_dict() + sd["shadow_params"] = sd["shadow_params"][:-1] + with pytest.raises(ValueError): + ema.load_state_dict(sd) -- cgit v1.2.3 From 81120309acff37307b2226fbda12277ca1662f93 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 13:42:20 -0600 Subject: Add .to() --- tests/test_ema.py | 17 +++++++++++++++++ torch_ema/ema.py | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+) (limited to 'tests') diff --git a/tests/test_ema.py b/tests/test_ema.py index aa43b14..edcea4c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -134,3 +134,20 @@ def test_explicit_params(): ema.update(model2.parameters()) ema.copy_to() assert not torch.all(model.weight == 0.0) + + +def test_to(): + m = torch.nn.Linear(11, 3) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + assert ema.shadow_params[0].dtype == torch.get_default_dtype() + ema.to(dtype=torch.float16) + assert ema.shadow_params[0].dtype == torch.float16 + ema.store() + # we store whatever we get + assert ema.collected_params[0].dtype == torch.get_default_dtype() + m = m.to(torch.float16) + ema.store(m.parameters()) + assert ema.collected_params[0].dtype == torch.float16 + ema.to(dtype=torch.float64) + assert ema.collected_params[0].dtype == torch.float64 + assert ema.shadow_params[0].dtype == torch.float64 diff --git a/torch_ema/ema.py b/torch_ema/ema.py index b3487cf..2aa3004 100644 --- a/torch_ema/ema.py +++ 
def to(self, device=None, dtype=None) -> None:
    r"""Move the internal buffers to `device` and/or cast them to `dtype`.

    ``self.shadow_params`` and, when it is not ``None``,
    ``self.collected_params`` are each replaced by a new list holding the
    converted tensors.

    Args:
        device: like `device` argument to `torch.Tensor.to`
        dtype: like `dtype` argument to `torch.Tensor.to`. It is applied
            to floating-point buffers only; any other buffer is moved to
            `device` but keeps its dtype.
    """
    def _convert(p):
        # .to() on the tensors handles None correctly, so an omitted
        # `device` or `dtype` leaves that property of the buffer as is.
        if p.is_floating_point():
            return p.to(device=device, dtype=dtype)
        # Never cast non-floating-point buffers; only relocate them.
        return p.to(device=device)

    self.shadow_params = [_convert(p) for p in self.shadow_params]
    if self.collected_params is not None:
        self.collected_params = [_convert(p) for p in self.collected_params]