aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/test_ema.py51
-rw-r--r--torch_ema/ema.py31
2 files changed, 82 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index edcea4c..fa90a8c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -71,6 +71,57 @@ def test_val_error(decay, use_num_updates, explicit_params):
"Restored model wasn't the same as stored model"
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_contextmanager(explicit_params):
+ """Confirm that EMA validation error is lower than raw validation error."""
+ torch.manual_seed(0)
+ x_train = torch.rand((100, 10))
+ y_train = torch.rand(100).round().long()
+ x_val = torch.rand((100, 10))
+ y_val = torch.rand(100).round().long()
+ model = torch.nn.Linear(10, 2)
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=0.99,
+ )
+
+ # Train for a few epochs
+ model.train()
+ for _ in range(20):
+ logits = model(x_train)
+ loss = torch.nn.functional.cross_entropy(logits, y_train)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if explicit_params:
+ ema.update(model.parameters())
+ else:
+ ema.update()
+
+ final_weight = model.weight.clone().detach()
+
+ # Validation: original
+ model.eval()
+ logits = model(x_val)
+ loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
+ print(f"Original loss: {loss_orig}")
+
+ # Validation: with EMA
+ if explicit_params:
+ cm = ema.average_parameters(model.parameters())
+ else:
+ cm = ema.average_parameters()
+
+ with cm:
+ logits = model(x_val)
+ loss_ema = torch.nn.functional.cross_entropy(logits, y_val)
+
+ print(f"EMA loss: {loss_ema}")
+ assert loss_ema < loss_orig, "EMA loss wasn't lower"
+ assert torch.all(model.weight == final_weight), "Restore failed"
+
+
@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
@pytest.mark.parametrize("use_num_updates", [True, False])
@pytest.mark.parametrize("explicit_params", [True, False])
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 3bcb465..1d03fd6 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -166,6 +166,37 @@ class ExponentialMovingAverage:
if param.requires_grad:
param.data.copy_(c_param.data)
+ @contextlib.contextmanager
+ def average_parameters(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ):
+ r"""
+ Context manager for validation/inference with averaged parameters.
+
+ Equivalent to:
+
+ ema.store()
+ ema.copy_to()
+ try:
+ ...
+ finally:
+ ema.restore()
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.store(parameters)
+ self.copy_to(parameters)
+ try:
+ yield
+ finally:
+ self.restore(parameters)
+
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.