aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Fadel <samuelfadel@gmail.com>2021-03-03 13:38:52 +0100
committerSamuel Fadel <samuelfadel@gmail.com>2021-03-03 13:38:52 +0100
commit30fb07f4d277fe70cd7596c9be98faf3c30f52fc (patch)
tree8e7db763f6c5cfda8c666a1f584456f6c2a7f434
parent18f4ea37fe6993b0e3a4a45d28421d1e223cfeac (diff)
Updated README.md with store/restore functionality.
-rw-r--r--README.md24
1 files changed, 15 insertions, 9 deletions
diff --git a/README.md b/README.md
index 09f3e18..a74db20 100644
--- a/README.md
+++ b/README.md
@@ -21,16 +21,18 @@ import torch.nn.functional as F
from torch_ema import ExponentialMovingAverage
-
+torch.manual_seed(0)
x_train = torch.rand((100, 10))
y_train = torch.rand(100).round().long()
+x_val = torch.rand((100, 10))
+y_val = torch.rand(100).round().long()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
# Train for a few epochs
model.train()
-for _ in range(10):
+for _ in range(20):
logits = model(x_train)
loss = F.cross_entropy(logits, y_train)
optimizer.zero_grad()
@@ -38,16 +40,20 @@ for _ in range(10):
optimizer.step()
ema.update(model.parameters())
-# Compare losses:
-# Original
+# Validation: original
model.eval()
-logits = model(x_train)
-loss = F.cross_entropy(logits, y_train)
+logits = model(x_val)
+loss = F.cross_entropy(logits, y_val)
print(loss.item())
-# With EMA
+# Validation: with EMA
+# First save original parameters before replacing with EMA version
+ema.store(model.parameters())
+# Copy EMA parameters to model
ema.copy_to(model.parameters())
-logits = model(x_train)
-loss = F.cross_entropy(logits, y_train)
+logits = model(x_val)
+loss = F.cross_entropy(logits, y_val)
print(loss.item())
+# Restore original parameters to resume training later
+ema.restore(model.parameters())
```