From 5bf836aba916fa368bed96b56bc5a5048e8b99bd Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 14:31:58 -0600 Subject: context manager for validation --- tests/test_ema.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ torch_ema/ema.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) diff --git a/tests/test_ema.py b/tests/test_ema.py index edcea4c..fa90a8c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -71,6 +71,57 @@ def test_val_error(decay, use_num_updates, explicit_params): "Restored model wasn't the same as stored model" +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_contextmanager(explicit_params): + """Confirm that EMA validation error is lower than raw validation error.""" + torch.manual_seed(0) + x_train = torch.rand((100, 10)) + y_train = torch.rand(100).round().long() + x_val = torch.rand((100, 10)) + y_val = torch.rand(100).round().long() + model = torch.nn.Linear(10, 2) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) + ema = ExponentialMovingAverage( + model.parameters(), + decay=0.99, + ) + + # Train for a few epochs + model.train() + for _ in range(20): + logits = model(x_train) + loss = torch.nn.functional.cross_entropy(logits, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if explicit_params: + ema.update(model.parameters()) + else: + ema.update() + + final_weight = model.weight.clone().detach() + + # Validation: original + model.eval() + logits = model(x_val) + loss_orig = torch.nn.functional.cross_entropy(logits, y_val) + print(f"Original loss: {loss_orig}") + + # Validation: with EMA + if explicit_params: + cm = ema.average_parameters(model.parameters()) + else: + cm = ema.average_parameters() + + with cm: + logits = model(x_val) + loss_ema = torch.nn.functional.cross_entropy(logits, y_val) + + print(f"EMA loss: {loss_ema}") + assert loss_ema < loss_orig, "EMA loss wasn't lower" + assert torch.all(model.weight == final_weight), "Restore failed" + + @pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0]) @pytest.mark.parametrize("use_num_updates", [True, False]) @pytest.mark.parametrize("explicit_params", [True, False]) diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 3bcb465..1d03fd6 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -166,6 +166,37 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(c_param.data) + @contextlib.contextmanager + def average_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + self.restore(parameters) + def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. -- cgit v1.2.3