diff options
-rw-r--r-- | README.md | 62 | ||||
-rw-r--r-- | tests/test_ema.py | 51 | ||||
-rw-r--r-- | torch_ema/ema.py | 31 |
3 files changed, 131 insertions, 13 deletions
@@ -1,9 +1,9 @@ # pytorch_ema -A very small library for computing exponential moving averages of model +A small library for computing exponential moving averages of model parameters. -This library was written for personal use. Nevertheless, if you run into issues +This library was originally written for personal use. Nevertheless, if you run into issues or have suggestions for improvement, feel free to open either a new issue or pull request. @@ -13,7 +13,9 @@ pull request. pip install -U git+https://github.com/fadel/pytorch_ema ``` -## Example +## Usage + +### Example ```python import torch @@ -38,7 +40,8 @@ for _ in range(20): optimizer.zero_grad() loss.backward() optimizer.step() - ema.update(model.parameters()) + # Update the moving average with the new parameters from the last optimizer step + ema.update() # Validation: original model.eval() @@ -47,13 +50,46 @@ loss = F.cross_entropy(logits, y_val) print(loss.item()) # Validation: with EMA -# First save original parameters before replacing with EMA version -ema.store(model.parameters()) -# Copy EMA parameters to model -ema.copy_to(model.parameters()) -logits = model(x_val) -loss = F.cross_entropy(logits, y_val) -print(loss.item()) -# Restore original parameters to resume training later -ema.restore(model.parameters()) +# the .average_parameters() context manager +# (1) saves original parameters before replacing with EMA version +# (2) copies EMA parameters to model +# (3) after exiting the `with`, restore original parameters to resume training later +with ema.average_parameters(): + logits = model(x_val) + loss = F.cross_entropy(logits, y_val) + print(loss.item()) +``` + +### Manual validation mode + +While the `average_parameters()` context manager is convinient, you can also manually execute the same series of operations: +```python +ema.store() +ema.copy_to() +# ... +ema.restore() +``` + +### Custom parameters + +By default the methods of `ExponentialMovingAverage` act on the model parameters the object was constructed with, but any compatable iterable of parameters can be passed to any method (such as `store()`, `copy_to()`, `update()`, `restore()`, and `average_parameters()`): +```python +model = torch.nn.Linear(10, 2) +model2 = torch.nn.Linear(10, 2) +ema = ExponentialMovingAverage(model.parameters(), decay=0.995) +# train +# calling `ema.update()` will use `model.parameters()` +ema.copy_to(model2) +# model2 now contains the averaged weights ``` + +### Resuming training + +Like a PyTorch optimizer, `ExponentialMovingAverage` objects have `state_dict()`/`load_state_dict()` methods to allow pausing, serializing, and restarting training without loosing shadow parameters, stored parameters, or the update count. + +### GPU/device support + +`ExponentialMovingAverage` objects have a `.to()` function (like `torch.Tensor`) that can move the object's internal state to a different device or floating-point dtype. + + +For more details on individual methods, please check the docstrings.
\ No newline at end of file diff --git a/tests/test_ema.py b/tests/test_ema.py index edcea4c..fa90a8c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -71,6 +71,57 @@ def test_val_error(decay, use_num_updates, explicit_params): "Restored model wasn't the same as stored model" +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_contextmanager(explicit_params): + """Confirm that EMA validation error is lower than raw validation error.""" + torch.manual_seed(0) + x_train = torch.rand((100, 10)) + y_train = torch.rand(100).round().long() + x_val = torch.rand((100, 10)) + y_val = torch.rand(100).round().long() + model = torch.nn.Linear(10, 2) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) + ema = ExponentialMovingAverage( + model.parameters(), + decay=0.99, + ) + + # Train for a few epochs + model.train() + for _ in range(20): + logits = model(x_train) + loss = torch.nn.functional.cross_entropy(logits, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if explicit_params: + ema.update(model.parameters()) + else: + ema.update() + + final_weight = model.weight.clone().detach() + + # Validation: original + model.eval() + logits = model(x_val) + loss_orig = torch.nn.functional.cross_entropy(logits, y_val) + print(f"Original loss: {loss_orig}") + + # Validation: with EMA + if explicit_params: + cm = ema.average_parameters(model.parameters()) + else: + cm = ema.average_parameters() + + with cm: + logits = model(x_val) + loss_ema = torch.nn.functional.cross_entropy(logits, y_val) + + print(f"EMA loss: {loss_ema}") + assert loss_ema < loss_orig, "EMA loss wasn't lower" + assert torch.all(model.weight == final_weight), "Restore failed" + + @pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0]) @pytest.mark.parametrize("use_num_updates", [True, False]) @pytest.mark.parametrize("explicit_params", [True, False]) diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 3bcb465..1d03fd6 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -166,6 +166,37 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(c_param.data) + @contextlib.contextmanager + def average_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + self.restore(parameters) + def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. |