-rw-r--r--  setup.py           |  2
-rw-r--r--  tests/test_ema.py  | 96
-rw-r--r--  torch_ema/ema.py   | 50
3 files changed, 128 insertions(+), 20 deletions(-)
diff --git a/setup.py b/setup.py
index d998024..550c159 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@ __version__ = '0.2'
url = 'https://github.com/fadel/pytorch_ema'
download_url = '{}/archive/{}.tar.gz'.format(url, __version__)
-install_requires = []
+install_requires = ["torch"]
setup_requires = []
tests_require = []
diff --git a/tests/test_ema.py b/tests/test_ema.py
new file mode 100644
index 0000000..6d7e43e
--- /dev/null
+++ b/tests/test_ema.py
@@ -0,0 +1,96 @@
+import pytest
+
+import torch
+
+from torch_ema import ExponentialMovingAverage
+
+
+@pytest.mark.parametrize("decay", [0.995, 0.9])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+def test_val_error(decay, use_num_updates):
+ """Confirm that EMA validation error is lower than raw validation error."""
+ torch.manual_seed(0)
+ x_train = torch.rand((100, 10))
+ y_train = torch.rand(100).round().long()
+ x_val = torch.rand((100, 10))
+ y_val = torch.rand(100).round().long()
+ model = torch.nn.Linear(10, 2)
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=decay,
+ use_num_updates=use_num_updates
+ )
+
+ # Train for a few epochs
+ model.train()
+ for _ in range(20):
+ logits = model(x_train)
+ loss = torch.nn.functional.cross_entropy(logits, y_train)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ ema.update(model.parameters())
+
+ # Validation: original
+ model.eval()
+ logits = model(x_val)
+ loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
+ print(f"Original loss: {loss_orig}")
+
+ # Validation: with EMA
+ # First save original parameters before replacing with EMA version
+ ema.store(model.parameters())
+ # Copy EMA parameters to model
+ ema.copy_to(model.parameters())
+ logits = model(x_val)
+ loss_ema = torch.nn.functional.cross_entropy(logits, y_val)
+
+ print(f"EMA loss: {loss_ema}")
+ assert loss_ema < loss_orig, "EMA loss wasn't lower"
+
+ # Test restore
+ ema.restore(model.parameters())
+ model.eval()
+ logits = model(x_val)
+ loss_orig2 = torch.nn.functional.cross_entropy(logits, y_val)
+ assert torch.allclose(loss_orig, loss_orig2), \
+ "Restored model wasn't the same as stored model"
+
+
+@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+def test_store_restore(decay, use_num_updates):
+ model = torch.nn.Linear(10, 2)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=decay,
+ use_num_updates=use_num_updates
+ )
+ orig_weight = model.weight.clone().detach()
+ ema.store(model.parameters())
+ with torch.no_grad():
+ model.weight.uniform_(0.0, 1.0)
+ ema.restore(model.parameters())
+ assert torch.all(model.weight == orig_weight)
+
+
+@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
+def test_update(decay):
+ model = torch.nn.Linear(10, 2, bias=False)
+ with torch.no_grad():
+ model.weight.fill_(0.0)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=decay,
+ use_num_updates=False
+ )
+ with torch.no_grad():
+ model.weight.fill_(1.0)
+ ema.update(model.parameters())
+ assert torch.all(model.weight == 1.0), "ema.update changed model weights"
+ ema.copy_to(model.parameters())
+ assert torch.allclose(
+ model.weight,
+ torch.full(size=(1,), fill_value=(1.0 - decay))
+ ), "average was wrong"
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 7771ef7..0233c78 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -1,23 +1,30 @@
from __future__ import division
from __future__ import unicode_literals
+from typing import Iterable
+
import torch
-# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+# Partially based on:
+# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
- """
- def __init__(self, parameters, decay, use_num_updates=True):
- """
- Args:
- parameters: Iterable of `torch.nn.Parameter`; usually the result of
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the result of
`model.parameters()`.
- decay: The exponential decay.
- use_num_updates: Whether to use number of updates when computing
+ decay: The exponential decay.
+ use_num_updates: Whether to use number of updates when computing
averages.
- """
+ """
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter],
+ decay: float,
+ use_num_updates: bool = True
+ ):
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
@@ -26,7 +33,7 @@ class ExponentialMovingAverage:
for p in parameters if p.requires_grad]
self.collected_params = []
- def update(self, parameters):
+ def update(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Update currently maintained parameters.
@@ -40,18 +47,24 @@ class ExponentialMovingAverage:
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
- decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+ decay = min(
+ decay,
+ (1 + self.num_updates) / (10 + self.num_updates)
+ )
one_minus_decay = 1.0 - decay
with torch.no_grad():
parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
- s_param.sub_(one_minus_decay * (s_param - param))
+ tmp = (s_param - param)
+            # tmp is a new tensor, so in-place ops on it won't modify s_param
+ tmp.mul_(one_minus_decay)
+ s_param.sub_(tmp)
- def copy_to(self, parameters):
+ def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Copy current parameters into given collection of parameters.
- Args:
+ Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages.
"""
@@ -59,11 +72,11 @@ class ExponentialMovingAverage:
if param.requires_grad:
param.data.copy_(s_param.data)
- def store(self, parameters):
+ def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Save the current parameters for restoring later.
- Args:
+ Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
@@ -71,7 +84,7 @@ class ExponentialMovingAverage:
for param in parameters
if param.requires_grad]
- def restore(self, parameters):
+ def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
@@ -79,11 +92,10 @@ class ExponentialMovingAverage:
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
- Args:
+ Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
if param.requires_grad:
param.data.copy_(c_param.data)
-
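
Taken together, the newly typed API supports the usual train-then-evaluate pattern: update after each optimizer step, store and copy_to before validation, restore afterwards. A minimal usage sketch (the model, data, and hyperparameters below are placeholders, not part of this commit):

    import torch
    from torch_ema import ExponentialMovingAverage

    model = torch.nn.Linear(10, 2)                 # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    for _ in range(100):                           # placeholder training loop
        x = torch.rand(32, 10)
        y = torch.rand(32).round().long()
        loss = torch.nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model.parameters())             # track averages after each step

    # Validate with the averaged weights, then put the training weights back.
    ema.store(model.parameters())
    ema.copy_to(model.parameters())
    # ... run validation here ...
    ema.restore(model.parameters())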