author | Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-04-20 12:03:13 -0600
committer | GitHub <noreply@github.com> | 2021-04-20 20:03:13 +0200
commit | 1a47d1c99784d1247f41308d335b2087de13334c (patch)
tree | 78803311cbf29ac47cac1a107e121a20d0af73a7
parent | 3950a7b5c4b88f46fd14f620277bad21898597a9 (diff)
Tests and miscellaneous (#4)
* linter
* Add tests
* Add restore test
* Type hints
* PyTorch dependency
* One fewer temp buffer (see the sketch below)
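The "one fewer temp buffer" item refers to the EMA update in `torch_ema/ema.py` below: instead of allocating the product `one_minus_decay * (s_param - param)` as a second temporary tensor, the patch keeps the difference tensor and scales it in place. A minimal sketch of just that arithmetic, with illustrative values (not the patched class itself):

```python
import torch

# Both forms compute: shadow <- shadow - (1 - decay) * (shadow - param)
one_minus_decay = 0.1
s_param = torch.ones(3)    # stands in for a shadow (EMA) parameter
param = torch.zeros(3)     # stands in for the current model parameter

# Old form: s_param.sub_(one_minus_decay * (s_param - param))
# allocates the difference *and* the scaled product.
# New form: reuse the difference tensor by scaling it in place.
tmp = s_param - param      # the only temporary tensor
tmp.mul_(one_minus_decay)  # scaled in place, no extra allocation
s_param.sub_(tmp)
print(s_param)             # tensor([0.9000, 0.9000, 0.9000])
```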
-rw-r--r-- | setup.py | 2
-rw-r--r-- | tests/test_ema.py | 96
-rw-r--r-- | torch_ema/ema.py | 50
3 files changed, 128 insertions, 20 deletions
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@ __version__ = '0.2'
 
 url = 'https://github.com/fadel/pytorch_ema'
 download_url = '{}/archive/{}.tar.gz'.format(url, __version__)
-install_requires = []
+install_requires = ["torch"]
 setup_requires = []
 tests_require = []
diff --git a/tests/test_ema.py b/tests/test_ema.py
new file mode 100644
index 0000000..6d7e43e
--- /dev/null
+++ b/tests/test_ema.py
@@ -0,0 +1,96 @@
+import pytest
+
+import torch
+
+from torch_ema import ExponentialMovingAverage
+
+
+@pytest.mark.parametrize("decay", [0.995, 0.9])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+def test_val_error(decay, use_num_updates):
+    """Confirm that EMA validation error is lower than raw validation error."""
+    torch.manual_seed(0)
+    x_train = torch.rand((100, 10))
+    y_train = torch.rand(100).round().long()
+    x_val = torch.rand((100, 10))
+    y_val = torch.rand(100).round().long()
+    model = torch.nn.Linear(10, 2)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+    ema = ExponentialMovingAverage(
+        model.parameters(),
+        decay=decay,
+        use_num_updates=use_num_updates
+    )
+
+    # Train for a few epochs
+    model.train()
+    for _ in range(20):
+        logits = model(x_train)
+        loss = torch.nn.functional.cross_entropy(logits, y_train)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        ema.update(model.parameters())
+
+    # Validation: original
+    model.eval()
+    logits = model(x_val)
+    loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
+    print(f"Original loss: {loss_orig}")
+
+    # Validation: with EMA
+    # First save original parameters before replacing with EMA version
+    ema.store(model.parameters())
+    # Copy EMA parameters to model
+    ema.copy_to(model.parameters())
+    logits = model(x_val)
+    loss_ema = torch.nn.functional.cross_entropy(logits, y_val)
+
+    print(f"EMA loss: {loss_ema}")
+    assert loss_ema < loss_orig, "EMA loss wasn't lower"
+
+    # Test restore
+    ema.restore(model.parameters())
+    model.eval()
+    logits = model(x_val)
+    loss_orig2 = torch.nn.functional.cross_entropy(logits, y_val)
+    assert torch.allclose(loss_orig, loss_orig2), \
+        "Restored model wasn't the same as stored model"
+
+
+@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+def test_store_restore(decay, use_num_updates):
+    model = torch.nn.Linear(10, 2)
+    ema = ExponentialMovingAverage(
+        model.parameters(),
+        decay=decay,
+        use_num_updates=use_num_updates
+    )
+    orig_weight = model.weight.clone().detach()
+    ema.store(model.parameters())
+    with torch.no_grad():
+        model.weight.uniform_(0.0, 1.0)
+    ema.restore(model.parameters())
+    assert torch.all(model.weight == orig_weight)
+
+
+@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
+def test_update(decay):
+    model = torch.nn.Linear(10, 2, bias=False)
+    with torch.no_grad():
+        model.weight.fill_(0.0)
+    ema = ExponentialMovingAverage(
+        model.parameters(),
+        decay=decay,
+        use_num_updates=False
+    )
+    with torch.no_grad():
+        model.weight.fill_(1.0)
+    ema.update(model.parameters())
+    assert torch.all(model.weight == 1.0), "ema.update changed model weights"
+    ema.copy_to(model.parameters())
+    assert torch.allclose(
+        model.weight,
+        torch.full(size=(1,), fill_value=(1.0 - decay))
+    ), "average was wrong"
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 7771ef7..0233c78 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -1,23 +1,30 @@
 from __future__ import division
 from __future__ import unicode_literals
 
+from typing import Iterable
+
 import torch
 
-# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+# Partially based on:
+# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
 
 
 class ExponentialMovingAverage:
     """
     Maintains (exponential) moving average of a set of parameters.
-    """
-
-    def __init__(self, parameters, decay, use_num_updates=True):
-        """
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; usually the result of
+
+    Args:
+        parameters: Iterable of `torch.nn.Parameter`; usually the result of
             `model.parameters()`.
-            decay: The exponential decay.
-            use_num_updates: Whether to use number of updates when computing
+        decay: The exponential decay.
+        use_num_updates: Whether to use number of updates when computing
             averages.
-        """
+    """
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        decay: float,
+        use_num_updates: bool = True
+    ):
         if decay < 0.0 or decay > 1.0:
             raise ValueError('Decay must be between 0 and 1')
         self.decay = decay
@@ -26,7 +33,7 @@ class ExponentialMovingAverage:
                               for p in parameters if p.requires_grad]
         self.collected_params = []
 
-    def update(self, parameters):
+    def update(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         Update currently maintained parameters.
 
@@ -40,18 +47,24 @@ class ExponentialMovingAverage:
         decay = self.decay
         if self.num_updates is not None:
             self.num_updates += 1
-            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+            decay = min(
+                decay,
+                (1 + self.num_updates) / (10 + self.num_updates)
+            )
         one_minus_decay = 1.0 - decay
         with torch.no_grad():
             parameters = [p for p in parameters if p.requires_grad]
             for s_param, param in zip(self.shadow_params, parameters):
-                s_param.sub_(one_minus_decay * (s_param - param))
+                tmp = (s_param - param)
+                # tmp will be a new tensor so we can do in-place
+                tmp.mul_(one_minus_decay)
+                s_param.sub_(tmp)
 
-    def copy_to(self, parameters):
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         Copy current parameters into given collection of parameters.
 
-        Args:
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored moving averages.
         """
@@ -59,11 +72,11 @@ class ExponentialMovingAverage:
             if param.requires_grad:
                 param.data.copy_(s_param.data)
 
-    def store(self, parameters):
+    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         Save the current parameters for restoring later.
 
-        Args:
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 temporarily stored.
         """
@@ -71,7 +84,7 @@ class ExponentialMovingAverage:
                                  for param in parameters
                                  if param.requires_grad]
 
-    def restore(self, parameters):
+    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         Restore the parameters stored with the `store` method.
         Useful to validate the model with EMA parameters without affecting the
@@ -79,11 +92,10 @@ class ExponentialMovingAverage:
         `copy_to` method. After validation (or model saving), use this to
         restore the former parameters.
 
-        Args:
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored parameters.
         """
         for c_param, param in zip(self.collected_params, parameters):
             if param.requires_grad:
                 param.data.copy_(c_param.data)
-
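For reference, the `use_num_updates` option exercised by the new tests caps the decay early in training: `update()` above takes `min(decay, (1 + num_updates) / (10 + num_updates))`, so the shadow parameters track the model closely at first and only later settle into the configured decay. A minimal sketch of that schedule in plain Python (not part of the patch; the loop and variable names are illustrative):

```python
# Effective decay used by update() when use_num_updates=True,
# per the formula in the patch: min(decay, (1 + n) / (10 + n)).
decay = 0.995

for n in range(1, 6):
    effective = min(decay, (1 + n) / (10 + n))
    print(f"after update {n}: effective decay = {effective:.3f}")

# Prints 0.182, 0.250, 0.308, 0.357, 0.400; the cap approaches 1 as n grows,
# so the configured decay of 0.995 eventually takes over.
```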