| author    | Samuel Fadel <samuelfadel@gmail.com>             | 2021-03-03 13:38:52 +0100 |
|-----------|--------------------------------------------------|---------------------------|
| committer | Samuel Fadel <samuelfadel@gmail.com>             | 2021-03-03 13:38:52 +0100 |
| commit    | 30fb07f4d277fe70cd7596c9be98faf3c30f52fc (patch)  |                           |
| tree      | 8e7db763f6c5cfda8c666a1f584456f6c2a7f434          |                           |
| parent    | 18f4ea37fe6993b0e3a4a45d28421d1e223cfeac (diff)   |                           |
Updated README.md with store/restore functionality.
-rw-r--r-- | README.md | 24
1 file changed, 15 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -21,16 +21,18 @@ import torch.nn.functional as F
 from torch_ema import ExponentialMovingAverage
 
-
+torch.manual_seed(0)
 x_train = torch.rand((100, 10))
 y_train = torch.rand(100).round().long()
+x_val = torch.rand((100, 10))
+y_val = torch.rand(100).round().long()
 
 model = torch.nn.Linear(10, 2)
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
 ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
 
 # Train for a few epochs
 model.train()
-for _ in range(10):
+for _ in range(20):
     logits = model(x_train)
     loss = F.cross_entropy(logits, y_train)
     optimizer.zero_grad()
@@ -38,16 +40,20 @@ for _ in range(10):
     optimizer.step()
     ema.update(model.parameters())
 
-# Compare losses:
-# Original
+# Validation: original
 model.eval()
-logits = model(x_train)
-loss = F.cross_entropy(logits, y_train)
+logits = model(x_val)
+loss = F.cross_entropy(logits, y_val)
 print(loss.item())
 
-# With EMA
+# Validation: with EMA
+# First save original parameters before replacing with EMA version
+ema.store(model.parameters())
+# Copy EMA parameters to model
 ema.copy_to(model.parameters())
-logits = model(x_train)
-loss = F.cross_entropy(logits, y_train)
+logits = model(x_val)
+loss = F.cross_entropy(logits, y_val)
 print(loss.item())
+# Restore original parameters to resume training later
+ema.restore(model.parameters())
 ```
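For intuition about what the `ema.update()` call in the training loop accumulates: the standard exponential moving average rule is `shadow = decay * shadow + (1 - decay) * param`. The sketch below reproduces that arithmetic by hand; it is a simplification, since the library may additionally adjust the effective decay with the number of updates.

```python
import torch

decay = 0.995
param = torch.tensor([1.0, 2.0])
shadow = param.clone()  # the EMA starts at the initial parameter value

for _ in range(3):
    param = param + 0.1  # stand-in for an optimizer step
    # Standard EMA rule: the shadow takes a small step toward the
    # current parameters at every update.
    shadow = decay * shadow + (1.0 - decay) * param

print(param)   # tensor([1.3000, 2.3000])
print(shadow)  # still close to the initial values; the EMA lags the raw weights
```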
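The `store`/`copy_to`/`restore` sequence shown in the new README text is easy to get wrong if an exception fires between `copy_to` and `restore`, leaving the model stuck with EMA weights. As a minimal sketch, the three calls can be wrapped in a context manager so the training weights are always put back; the `ema_parameters` helper below is hypothetical, not part of torch_ema at this commit.

```python
import contextlib

import torch
from torch_ema import ExponentialMovingAverage

@contextlib.contextmanager
def ema_parameters(ema, parameters):
    # `parameters` may be a one-shot generator (e.g. model.parameters()),
    # so materialize it once and reuse the list for all three calls.
    parameters = list(parameters)
    ema.store(parameters)    # save the current training weights
    ema.copy_to(parameters)  # swap in the EMA weights
    try:
        yield
    finally:
        ema.restore(parameters)  # always put the training weights back

model = torch.nn.Linear(10, 2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

with ema_parameters(ema, model.parameters()):
    print(model(torch.rand(3, 10)))  # evaluation runs with EMA weights
# Training weights are restored here, even if evaluation raised.
```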