-rw-r--r--  README.md          62
-rw-r--r--  tests/test_ema.py  51
-rw-r--r--  torch_ema/ema.py   31
3 files changed, 131 insertions, 13 deletions
diff --git a/README.md b/README.md
index a74db20..c9899ff 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,9 @@
# pytorch_ema
-A very small library for computing exponential moving averages of model
+A small library for computing exponential moving averages of model
parameters.
-This library was written for personal use. Nevertheless, if you run into issues
+This library was originally written for personal use. Nevertheless, if you run into issues
or have suggestions for improvement, feel free to open either a new issue or
pull request.
@@ -13,7 +13,9 @@ pull request.
pip install -U git+https://github.com/fadel/pytorch_ema
```
-## Example
+## Usage
+
+### Example
```python
import torch
@@ -38,7 +40,8 @@ for _ in range(20):
optimizer.zero_grad()
loss.backward()
optimizer.step()
- ema.update(model.parameters())
+ # Update the moving average with the new parameters from the last optimizer step
+ ema.update()
# Validation: original
model.eval()
@@ -47,13 +50,46 @@ loss = F.cross_entropy(logits, y_val)
print(loss.item())
# Validation: with EMA
-# First save original parameters before replacing with EMA version
-ema.store(model.parameters())
-# Copy EMA parameters to model
-ema.copy_to(model.parameters())
-logits = model(x_val)
-loss = F.cross_entropy(logits, y_val)
-print(loss.item())
-# Restore original parameters to resume training later
-ema.restore(model.parameters())
+# The .average_parameters() context manager
+# (1) saves the original parameters before replacing them with the EMA version,
+# (2) copies the EMA parameters to the model, and
+# (3) after exiting the `with`, restores the original parameters so training can resume
+with ema.average_parameters():
+ logits = model(x_val)
+ loss = F.cross_entropy(logits, y_val)
+ print(loss.item())
+```
+
+### Manual validation mode
+
+While the `average_parameters()` context manager is convenient, you can also manually execute the same series of operations:
+```python
+ema.store()
+ema.copy_to()
+# ...
+ema.restore()
+```
+
+### Custom parameters
+
+By default the methods of `ExponentialMovingAverage` act on the model parameters the object was constructed with, but any compatible iterable of parameters can be passed to any method (such as `store()`, `copy_to()`, `update()`, `restore()`, and `average_parameters()`):
+```python
+model = torch.nn.Linear(10, 2)
+model2 = torch.nn.Linear(10, 2)
+ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
+# train
+# calling `ema.update()` will use `model.parameters()`
+ema.copy_to(model2.parameters())
+# model2 now contains the averaged weights
```
+
+### Resuming training
+
+Like a PyTorch optimizer, `ExponentialMovingAverage` objects have `state_dict()`/`load_state_dict()` methods to allow pausing, serializing, and restarting training without losing shadow parameters, stored parameters, or the update count.
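+
+A minimal sketch of checkpointing with the EMA state (the file name and checkpoint keys are illustrative, not part of the library):
+```python
+# Save model, optimizer, and EMA state together
+torch.save({
+    "model": model.state_dict(),
+    "optimizer": optimizer.state_dict(),
+    "ema": ema.state_dict(),
+}, "checkpoint.pt")
+
+# ...later, load everything back to resume training
+checkpoint = torch.load("checkpoint.pt")
+model.load_state_dict(checkpoint["model"])
+optimizer.load_state_dict(checkpoint["optimizer"])
+ema.load_state_dict(checkpoint["ema"])
+```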
+
+### GPU/device support
+
+`ExponentialMovingAverage` objects have a `.to()` function (like `torch.Tensor`) that can move the object's internal state to a different device or floating-point dtype.
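+
+For example (a sketch; assumes a CUDA device is available):
+```python
+# Move the shadow (and any stored) parameters to the GPU
+ema.to(device=torch.device("cuda"))
+# Optionally cast the internal buffers, e.g. for half-precision evaluation
+ema.to(dtype=torch.float16)
+```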
+
+
+For more details on individual methods, please check the docstrings.
\ No newline at end of file
diff --git a/tests/test_ema.py b/tests/test_ema.py
index edcea4c..fa90a8c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -71,6 +71,57 @@ def test_val_error(decay, use_num_updates, explicit_params):
"Restored model wasn't the same as stored model"
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_contextmanager(explicit_params):
+ """Confirm that EMA validation error is lower than raw validation error."""
+ torch.manual_seed(0)
+ x_train = torch.rand((100, 10))
+ y_train = torch.rand(100).round().long()
+ x_val = torch.rand((100, 10))
+ y_val = torch.rand(100).round().long()
+ model = torch.nn.Linear(10, 2)
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=0.99,
+ )
+
+ # Train for a few epochs
+ model.train()
+ for _ in range(20):
+ logits = model(x_train)
+ loss = torch.nn.functional.cross_entropy(logits, y_train)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ if explicit_params:
+ ema.update(model.parameters())
+ else:
+ ema.update()
+
+ final_weight = model.weight.clone().detach()
+
+ # Validation: original
+ model.eval()
+ logits = model(x_val)
+ loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
+ print(f"Original loss: {loss_orig}")
+
+ # Validation: with EMA
+ if explicit_params:
+ cm = ema.average_parameters(model.parameters())
+ else:
+ cm = ema.average_parameters()
+
+ with cm:
+ logits = model(x_val)
+ loss_ema = torch.nn.functional.cross_entropy(logits, y_val)
+
+ print(f"EMA loss: {loss_ema}")
+ assert loss_ema < loss_orig, "EMA loss wasn't lower"
+ assert torch.all(model.weight == final_weight), "Restore failed"
+
+
@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
@pytest.mark.parametrize("use_num_updates", [True, False])
@pytest.mark.parametrize("explicit_params", [True, False])
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 3bcb465..1d03fd6 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -166,6 +166,37 @@ class ExponentialMovingAverage:
if param.requires_grad:
param.data.copy_(c_param.data)
+ @contextlib.contextmanager
+ def average_parameters(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ):
+ r"""
+ Context manager for validation/inference with averaged parameters.
+
+ Equivalent to:
+
+ ema.store()
+ ema.copy_to()
+ try:
+ ...
+ finally:
+ ema.restore()
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to
+ temporarily replace with their averaged versions. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.store(parameters)
+ self.copy_to(parameters)
+ try:
+ yield
+ finally:
+ self.restore(parameters)
+
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.