 README.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 49 insertions(+), 13 deletions(-)
diff --git a/README.md b/README.md
index a74db20..c9899ff 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,9 @@
# pytorch_ema
-A very small library for computing exponential moving averages of model
+A small library for computing exponential moving averages of model
parameters.
-This library was written for personal use. Nevertheless, if you run into issues
+This library was originally written for personal use. Nevertheless, if you run into issues
or have suggestions for improvement, feel free to open either a new issue or
pull request.
@@ -13,7 +13,9 @@ pull request.
pip install -U git+https://github.com/fadel/pytorch_ema
```
-## Example
+## Usage
+
+### Example
```python
import torch
@@ -38,7 +40,8 @@ for _ in range(20):
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
- ema.update(model.parameters())
+ # Update the moving average with the new parameters from the last optimizer step
+ ema.update()
# Validation: original
model.eval()
@@ -47,13 +50,46 @@ loss = F.cross_entropy(logits, y_val)
print(loss.item())
# Validation: with EMA
-# First save original parameters before replacing with EMA version
-ema.store(model.parameters())
-# Copy EMA parameters to model
-ema.copy_to(model.parameters())
-logits = model(x_val)
-loss = F.cross_entropy(logits, y_val)
-print(loss.item())
-# Restore original parameters to resume training later
-ema.restore(model.parameters())
+# the .average_parameters() context manager
+# (1) saves original parameters before replacing with EMA version
+# (2) copies EMA parameters to model
+# (3) after exiting the `with`, restores original parameters to resume training later
+with ema.average_parameters():
+ logits = model(x_val)
+ loss = F.cross_entropy(logits, y_val)
+ print(loss.item())
+```
+
+### Manual validation mode
+
+While the `average_parameters()` context manager is convenient, you can also manually execute the same sequence of operations:
+```python
+ema.store()    # save the current model parameters
+ema.copy_to()  # replace them with the EMA parameters
+# ... run validation ...
+ema.restore()  # put the original parameters back to resume training
+```
+
+### Custom parameters
+
+By default the methods of `ExponentialMovingAverage` act on the model parameters the object was constructed with, but any compatible iterable of parameters can be passed to any method (such as `store()`, `copy_to()`, `update()`, `restore()`, and `average_parameters()`):
+```python
+model = torch.nn.Linear(10, 2)
+model2 = torch.nn.Linear(10, 2)
+ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
+# train
+# calling `ema.update()` will use `model.parameters()`
+ema.copy_to(model2.parameters())
+# model2 now contains the averaged weights
```
+
+### Resuming training
+
+Like a PyTorch optimizer, `ExponentialMovingAverage` objects have `state_dict()`/`load_state_dict()` methods to allow pausing, serializing, and restarting training without losing shadow parameters, stored parameters, or the update count.
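+
+A minimal sketch of a checkpoint round trip (the `checkpoint.pt` filename is just an example), continuing the training example above:
+```python
+# Save the model, optimizer, and EMA state together
+torch.save({
+    "model": model.state_dict(),
+    "optimizer": optimizer.state_dict(),
+    "ema": ema.state_dict(),
+}, "checkpoint.pt")
+
+# ...later, restore everything and resume training where it left off
+checkpoint = torch.load("checkpoint.pt")
+model.load_state_dict(checkpoint["model"])
+optimizer.load_state_dict(checkpoint["optimizer"])
+ema.load_state_dict(checkpoint["ema"])
+```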
+
+### GPU/device support
+
+`ExponentialMovingAverage` objects have a `.to()` function (like `torch.Tensor`) that can move the object's internal state to a different device or floating-point dtype.
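+
+For example (a minimal sketch, assuming a CUDA device is available):
+```python
+ema.to(device=torch.device("cuda"))  # move the shadow parameters to the GPU
+```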
+
+For more details on individual methods, please check the docstrings.
\ No newline at end of file