diff options
author | Samuel Fadel <samuelfadel@gmail.com> | 2019-04-16 15:33:58 -0300 |
---|---|---|
committer | Samuel Fadel <samuelfadel@gmail.com> | 2019-04-16 15:33:58 -0300 |
commit | 09cfcf97e0e938a93867c7d445f1c9b4dcfea023 (patch) | |
tree | 0e048bd3affce4cc6b9d4e301729ed3a5c769731 /README.md |
Initial commit.
Diffstat (limited to 'README.md')
-rw-r--r-- | README.md | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/README.md b/README.md new file mode 100644 index 0000000..aeffbfe --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +# pytorch_ema + +A very small library for computing exponential moving averages of model +parameters. + +This library was written for personal use. Nevertheless, if you run into issues +or have suggestions for improvement, feel free to open either a new issue or +pull request. + +## Example + +```python +import torch +import torch.nn.functional as F + +from torch_ema import ExponentialMovingAverage + + +x_train = torch.rand((100, 10)) +y_train = torch.rand(100).round().long() +model = torch.nn.Linear(10, 2) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) +ema = ExponentialMovingAverage(model.parameters(), decay=0.995) + +# Train for a few epochs +model.train() +for _ in range(10): + logits = model(x_train) + loss = F.cross_entropy(logits, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + ema.update(model.parameters()) + +# Compare losses: +# Original +model.eval() +logits = model(x_train) +loss = F.cross_entropy(logits, y_train) +print(loss.item()) + +# With EMA +ema.copy_to(model.parameters()) +logits = model(x_train) +loss = F.cross_entropy(logits, y_train) +print(loss.item()) +``` |