diff options
-rw-r--r-- | tests/test_ema.py | 64 | ||||
-rw-r--r-- | torch_ema/ema.py | 73 |
2 files changed, 113 insertions, 24 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index 6d7e43e..ad6ee37 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -7,7 +7,8 @@ from torch_ema import ExponentialMovingAverage @pytest.mark.parametrize("decay", [0.995, 0.9]) @pytest.mark.parametrize("use_num_updates", [True, False]) -def test_val_error(decay, use_num_updates): +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_val_error(decay, use_num_updates, explicit_params): """Confirm that EMA validation error is lower than raw validation error.""" torch.manual_seed(0) x_train = torch.rand((100, 10)) @@ -30,27 +31,37 @@ def test_val_error(decay, use_num_updates): optimizer.zero_grad() loss.backward() optimizer.step() - ema.update(model.parameters()) + if explicit_params: + ema.update(model.parameters()) + else: + ema.update() # Validation: original model.eval() logits = model(x_val) loss_orig = torch.nn.functional.cross_entropy(logits, y_val) - print(f"Original loss: {loss_orig}") # Validation: with EMA # First save original parameters before replacing with EMA version - ema.store(model.parameters()) + if explicit_params: + ema.store(model.parameters()) + else: + ema.store() # Copy EMA parameters to model - ema.copy_to(model.parameters()) + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() logits = model(x_val) loss_ema = torch.nn.functional.cross_entropy(logits, y_val) - print(f"EMA loss: {loss_ema}") assert loss_ema < loss_orig, "EMA loss wasn't lower" # Test restore - ema.restore(model.parameters()) + if explicit_params: + ema.restore(model.parameters()) + else: + ema.restore() model.eval() logits = model(x_val) loss_orig2 = torch.nn.functional.cross_entropy(logits, y_val) @@ -60,7 +71,8 @@ def test_val_error(decay, use_num_updates): @pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0]) @pytest.mark.parametrize("use_num_updates", [True, False]) -def test_store_restore(decay, use_num_updates): +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_store_restore(decay, use_num_updates, explicit_params): model = torch.nn.Linear(10, 2) ema = ExponentialMovingAverage( model.parameters(), @@ -68,15 +80,22 @@ def test_store_restore(decay, use_num_updates): use_num_updates=use_num_updates ) orig_weight = model.weight.clone().detach() - ema.store(model.parameters()) + if explicit_params: + ema.store(model.parameters()) + else: + ema.store() with torch.no_grad(): model.weight.uniform_(0.0, 1.0) - ema.restore(model.parameters()) + if explicit_params: + ema.restore(model.parameters()) + else: + ema.restore() assert torch.all(model.weight == orig_weight) @pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0]) -def test_update(decay): +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_update(decay, explicit_params): model = torch.nn.Linear(10, 2, bias=False) with torch.no_grad(): model.weight.fill_(0.0) @@ -87,10 +106,29 @@ def test_update(decay): ) with torch.no_grad(): model.weight.fill_(1.0) - ema.update(model.parameters()) + if explicit_params: + ema.update(model.parameters()) + else: + ema.update() assert torch.all(model.weight == 1.0), "ema.update changed model weights" - ema.copy_to(model.parameters()) + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() assert torch.allclose( model.weight, torch.full(size=(1,), fill_value=(1.0 - decay)) ), "average was wrong" + + +def test_explicit_params(): + model = torch.nn.Linear(10, 2) + with torch.no_grad(): + model.weight.fill_(0.0) + ema = ExponentialMovingAverage(model.parameters(), decay=0.9) + model2 = torch.nn.Linear(10, 2) + with torch.no_grad(): + model2.weight.fill_(1.0) + ema.update(model2.parameters()) + ema.copy_to() + assert not torch.all(model.weight == 0.0)
\ No newline at end of file diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 0233c78..2e8eb6f 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -1,7 +1,8 @@ from __future__ import division from __future__ import unicode_literals -from typing import Iterable +from typing import Iterable, Optional +import weakref import torch @@ -13,8 +14,8 @@ class ExponentialMovingAverage: Maintains (exponential) moving average of a set of parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; usually the result of - `model.parameters()`. + parameters: Iterable of `torch.nn.Parameter` (typically from + `model.parameters()`). decay: The exponential decay. use_num_updates: Whether to use number of updates when computing averages. @@ -29,11 +30,40 @@ class ExponentialMovingAverage: raise ValueError('Decay must be between 0 and 1') self.decay = decay self.num_updates = 0 if use_num_updates else None + parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] self.collected_params = [] + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. + self._params_refs = [weakref.ref(p) for p in parameters] - def update(self, parameters: Iterable[torch.nn.Parameter]) -> None: + def _get_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this " + "ExponentialMovingAverage " + "was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep " + "the model to which they belong from being garbage " + "collected." + ) + return parameters + else: + return parameters + + def update( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Update currently maintained parameters. @@ -42,8 +72,11 @@ class ExponentialMovingAverage: Args: parameters: Iterable of `torch.nn.Parameter`; usually the same set of - parameters used to initialize this object. + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ + parameters = self._get_parameters(parameters) decay = self.decay if self.num_updates is not None: self.num_updates += 1 @@ -60,31 +93,46 @@ class ExponentialMovingAverage: tmp.mul_(one_minus_decay) s_param.sub_(tmp) - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + def copy_to( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Copy current parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ + parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: param.data.copy_(s_param.data) - def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: + def store( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Save the current parameters for restoring later. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. """ + parameters = self._get_parameters(parameters) self.collected_params = [param.clone() for param in parameters if param.requires_grad] - def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: + + def restore( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: """ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the @@ -94,8 +142,11 @@ class ExponentialMovingAverage: Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ + parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: param.data.copy_(c_param.data) |