diff options
-rw-r--r-- | README.md | 24 |
1 files changed, 15 insertions, 9 deletions
@@ -21,16 +21,18 @@ import torch.nn.functional as F
 
 from torch_ema import ExponentialMovingAverage
 
-
+torch.manual_seed(0)
 x_train = torch.rand((100, 10))
 y_train = torch.rand(100).round().long()
+x_val = torch.rand((100, 10))
+y_val = torch.rand(100).round().long()
 model = torch.nn.Linear(10, 2)
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
 ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
 
 # Train for a few epochs
 model.train()
-for _ in range(10):
+for _ in range(20):
     logits = model(x_train)
     loss = F.cross_entropy(logits, y_train)
     optimizer.zero_grad()
@@ -38,16 +40,20 @@ for _ in range(10):
     optimizer.step()
     ema.update(model.parameters())
 
-# Compare losses:
-# Original
+# Validation: original
 model.eval()
-logits = model(x_train)
-loss = F.cross_entropy(logits, y_train)
+logits = model(x_val)
+loss = F.cross_entropy(logits, y_val)
 print(loss.item())
 
-# With EMA
+# Validation: with EMA
+# First save original parameters before replacing with EMA version
+ema.store(model.parameters())
+# Copy EMA parameters to model
 ema.copy_to(model.parameters())
-logits = model(x_train)
-loss = F.cross_entropy(logits, y_train)
+logits = model(x_val)
+loss = F.cross_entropy(logits, y_val)
 print(loss.item())
+# Restore original parameters to resume training later
+ema.restore(model.parameters())
 ```