From 09cfcf97e0e938a93867c7d445f1c9b4dcfea023 Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Tue, 16 Apr 2019 15:33:58 -0300 Subject: Initial commit. --- README.md | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 README.md (limited to 'README.md') diff --git a/README.md b/README.md new file mode 100644 index 0000000..aeffbfe --- /dev/null +++ b/README.md @@ -0,0 +1,47 @@ +# pytorch_ema + +A very small library for computing exponential moving averages of model +parameters. + +This library was written for personal use. Nevertheless, if you run into issues +or have suggestions for improvement, feel free to open either a new issue or +pull request. + +## Example + +```python +import torch +import torch.nn.functional as F + +from torch_ema import ExponentialMovingAverage + + +x_train = torch.rand((100, 10)) +y_train = torch.rand(100).round().long() +model = torch.nn.Linear(10, 2) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) +ema = ExponentialMovingAverage(model.parameters(), decay=0.995) + +# Train for a few epochs +model.train() +for _ in range(10): + logits = model(x_train) + loss = F.cross_entropy(logits, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + ema.update(model.parameters()) + +# Compare losses: +# Original +model.eval() +logits = model(x_train) +loss = F.cross_entropy(logits, y_train) +print(loss.item()) + +# With EMA +ema.copy_to(model.parameters()) +logits = model(x_train) +loss = F.cross_entropy(logits, y_train) +print(loss.item()) +``` -- cgit v1.2.3