From 30fb07f4d277fe70cd7596c9be98faf3c30f52fc Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Wed, 3 Mar 2021 13:38:52 +0100 Subject: Updated README.md with store/restore functionality. --- README.md | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 09f3e18..a74db20 100644 --- a/README.md +++ b/README.md @@ -21,16 +21,18 @@ import torch.nn.functional as F from torch_ema import ExponentialMovingAverage - +torch.manual_seed(0) x_train = torch.rand((100, 10)) y_train = torch.rand(100).round().long() +x_val = torch.rand((100, 10)) +y_val = torch.rand(100).round().long() model = torch.nn.Linear(10, 2) optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) ema = ExponentialMovingAverage(model.parameters(), decay=0.995) # Train for a few epochs model.train() -for _ in range(10): +for _ in range(20): logits = model(x_train) loss = F.cross_entropy(logits, y_train) optimizer.zero_grad() @@ -38,16 +40,20 @@ for _ in range(10): optimizer.step() ema.update(model.parameters()) -# Compare losses: -# Original +# Validation: original model.eval() -logits = model(x_train) -loss = F.cross_entropy(logits, y_train) +logits = model(x_val) +loss = F.cross_entropy(logits, y_val) print(loss.item()) -# With EMA +# Validation: with EMA +# First save original parameters before replacing with EMA version +ema.store(model.parameters()) +# Copy EMA parameters to model ema.copy_to(model.parameters()) -logits = model(x_train) -loss = F.cross_entropy(logits, y_train) +logits = model(x_val) +loss = F.cross_entropy(logits, y_val) print(loss.item()) +# Restore original parameters to resume training later +ema.restore(model.parameters()) ``` -- cgit v1.2.3