aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/test_ema.py64
-rw-r--r--torch_ema/ema.py73
2 files changed, 113 insertions, 24 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index 6d7e43e..ad6ee37 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -7,7 +7,8 @@ from torch_ema import ExponentialMovingAverage
@pytest.mark.parametrize("decay", [0.995, 0.9])
@pytest.mark.parametrize("use_num_updates", [True, False])
-def test_val_error(decay, use_num_updates):
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_val_error(decay, use_num_updates, explicit_params):
"""Confirm that EMA validation error is lower than raw validation error."""
torch.manual_seed(0)
x_train = torch.rand((100, 10))
@@ -30,27 +31,37 @@ def test_val_error(decay, use_num_updates):
optimizer.zero_grad()
loss.backward()
optimizer.step()
- ema.update(model.parameters())
+ if explicit_params:
+ ema.update(model.parameters())
+ else:
+ ema.update()
# Validation: original
model.eval()
logits = model(x_val)
loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
- print(f"Original loss: {loss_orig}")
# Validation: with EMA
# First save original parameters before replacing with EMA version
- ema.store(model.parameters())
+ if explicit_params:
+ ema.store(model.parameters())
+ else:
+ ema.store()
# Copy EMA parameters to model
- ema.copy_to(model.parameters())
+ if explicit_params:
+ ema.copy_to(model.parameters())
+ else:
+ ema.copy_to()
logits = model(x_val)
loss_ema = torch.nn.functional.cross_entropy(logits, y_val)
- print(f"EMA loss: {loss_ema}")
assert loss_ema < loss_orig, "EMA loss wasn't lower"
# Test restore
- ema.restore(model.parameters())
+ if explicit_params:
+ ema.restore(model.parameters())
+ else:
+ ema.restore()
model.eval()
logits = model(x_val)
loss_orig2 = torch.nn.functional.cross_entropy(logits, y_val)
@@ -60,7 +71,8 @@ def test_val_error(decay, use_num_updates):
@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
@pytest.mark.parametrize("use_num_updates", [True, False])
-def test_store_restore(decay, use_num_updates):
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_store_restore(decay, use_num_updates, explicit_params):
model = torch.nn.Linear(10, 2)
ema = ExponentialMovingAverage(
model.parameters(),
@@ -68,15 +80,22 @@ def test_store_restore(decay, use_num_updates):
use_num_updates=use_num_updates
)
orig_weight = model.weight.clone().detach()
- ema.store(model.parameters())
+ if explicit_params:
+ ema.store(model.parameters())
+ else:
+ ema.store()
with torch.no_grad():
model.weight.uniform_(0.0, 1.0)
- ema.restore(model.parameters())
+ if explicit_params:
+ ema.restore(model.parameters())
+ else:
+ ema.restore()
assert torch.all(model.weight == orig_weight)
@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
-def test_update(decay):
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_update(decay, explicit_params):
model = torch.nn.Linear(10, 2, bias=False)
with torch.no_grad():
model.weight.fill_(0.0)
@@ -87,10 +106,29 @@ def test_update(decay):
)
with torch.no_grad():
model.weight.fill_(1.0)
- ema.update(model.parameters())
+ if explicit_params:
+ ema.update(model.parameters())
+ else:
+ ema.update()
assert torch.all(model.weight == 1.0), "ema.update changed model weights"
- ema.copy_to(model.parameters())
+ if explicit_params:
+ ema.copy_to(model.parameters())
+ else:
+ ema.copy_to()
assert torch.allclose(
model.weight,
torch.full(size=(1,), fill_value=(1.0 - decay))
), "average was wrong"
+
+
+def test_explicit_params():
+ model = torch.nn.Linear(10, 2)
+ with torch.no_grad():
+ model.weight.fill_(0.0)
+ ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
+ model2 = torch.nn.Linear(10, 2)
+ with torch.no_grad():
+ model2.weight.fill_(1.0)
+ ema.update(model2.parameters())
+ ema.copy_to()
+ assert not torch.all(model.weight == 0.0) \ No newline at end of file
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 0233c78..2e8eb6f 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -1,7 +1,8 @@
from __future__ import division
from __future__ import unicode_literals
-from typing import Iterable
+from typing import Iterable, Optional
+import weakref
import torch
@@ -13,8 +14,8 @@ class ExponentialMovingAverage:
Maintains (exponential) moving average of a set of parameters.
Args:
- parameters: Iterable of `torch.nn.Parameter`; usually the result of
- `model.parameters()`.
+ parameters: Iterable of `torch.nn.Parameter` (typically from
+ `model.parameters()`).
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing
averages.
@@ -29,11 +30,40 @@ class ExponentialMovingAverage:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
+ parameters = list(parameters)
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
self.collected_params = []
+ # By maintaining only a weakref to each parameter,
+ # we maintain the old GC behaviour of ExponentialMovingAverage:
+ # if the model goes out of scope but the ExponentialMovingAverage
+ # is kept, no references to the model or its parameters will be
+ # maintained, and the model will be cleaned up.
+ self._params_refs = [weakref.ref(p) for p in parameters]
- def update(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ def _get_parameters(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]]
+ ) -> Iterable[torch.nn.Parameter]:
+ if parameters is None:
+ parameters = [p() for p in self._params_refs]
+ if any(p is None for p in parameters):
+ raise ValueError(
+ "(One of) the parameters with which this "
+ "ExponentialMovingAverage "
+ "was initialized no longer exists (was garbage collected);"
+ " please either provide `parameters` explicitly or keep "
+ "the model to which they belong from being garbage "
+ "collected."
+ )
+ return parameters
+ else:
+ return parameters
+
+ def update(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
"""
Update currently maintained parameters.
@@ -42,8 +72,11 @@ class ExponentialMovingAverage:
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
- parameters used to initialize this object.
+ parameters used to initialize this object. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
+ parameters = self._get_parameters(parameters)
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
@@ -60,31 +93,46 @@ class ExponentialMovingAverage:
tmp.mul_(one_minus_decay)
s_param.sub_(tmp)
- def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ def copy_to(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
"""
Copy current parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored moving averages.
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
+ parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
if param.requires_grad:
param.data.copy_(s_param.data)
- def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+ def store(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored.
+ temporarily stored. If `None`, the parameters of with which this
+ `ExponentialMovingAverage` was initialized will be used.
"""
+ parameters = self._get_parameters(parameters)
self.collected_params = [param.clone()
for param in parameters
if param.requires_grad]
- def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+
+ def restore(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
@@ -94,8 +142,11 @@ class ExponentialMovingAverage:
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters.
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
+ parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
if param.requires_grad:
param.data.copy_(c_param.data)