diff options
-rw-r--r-- | tests/test_ema.py | 21 | ||||
-rw-r--r-- | torch_ema/__init__.py | 4 | ||||
-rw-r--r-- | torch_ema/ema.py | 30 |
3 files changed, 46 insertions, 9 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index fa90a8c..4bc1901 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -187,6 +187,27 @@ def test_explicit_params(): assert not torch.all(model.weight == 0.0) +def test_some_untrainable(): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(3)) + self.y = torch.nn.Parameter(torch.randn(3)) + self.y.requires_grad_(False) + + def forward(self, x): + return self.x * x + self.y + + model = Mod() + ema = ExponentialMovingAverage(model.parameters(), decay=0.9) + ema.update() + with torch.no_grad(): + model.x *= 1.1 + ema.update() + ema.store() + ema.copy_to() + + def test_to(): m = torch.nn.Linear(11, 3) ema = ExponentialMovingAverage(m.parameters(), decay=0.9) diff --git a/torch_ema/__init__.py b/torch_ema/__init__.py index 9732013..6cf180f 100644 --- a/torch_ema/__init__.py +++ b/torch_ema/__init__.py @@ -1 +1,3 @@ -from .ema import * +from .ema import ExponentialMovingAverage + +__all__ = [ExponentialMovingAverage] diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 1d03fd6..8077f80 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -18,7 +18,23 @@ class ExponentialMovingAverage: Args: parameters: Iterable of `torch.nn.Parameter` (typically from `model.parameters()`). + Note that EMA is computed on *all* provided parameters, + regardless of whether or not they have `requires_grad = True`; + this allows a single EMA object to be consistantly used even + if which parameters are trainable changes step to step. + + If you want to some parameters in the EMA, do not pass them + to the object in the first place. For example: + + ExponentialMovingAverage( + parameters=[p for p in model.parameters() if p.requires_grad], + decay=0.9 + ) + + will ignore parameters that do not require grad. + decay: The exponential decay. + use_num_updates: Whether to use number of updates when computing averages. """ @@ -33,8 +49,10 @@ class ExponentialMovingAverage: self.decay = decay self.num_updates = 0 if use_num_updates else None parameters = list(parameters) - self.shadow_params = [p.clone().detach() - for p in parameters if p.requires_grad] + self.shadow_params = [ + p.clone().detach() + for p in parameters + ] self.collected_params = None # By maintaining only a weakref to each parameter, # we maintain the old GC behaviour of ExponentialMovingAverage: @@ -95,7 +113,6 @@ class ExponentialMovingAverage: ) one_minus_decay = 1.0 - decay with torch.no_grad(): - parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): tmp = (s_param - param) # tmp will be a new tensor so we can do in-place @@ -117,8 +134,7 @@ class ExponentialMovingAverage: """ parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - param.data.copy_(s_param.data) + param.data.copy_(s_param.data) def store( self, @@ -136,7 +152,6 @@ class ExponentialMovingAverage: self.collected_params = [ param.clone() for param in parameters - if param.requires_grad ] def restore( @@ -163,8 +178,7 @@ class ExponentialMovingAverage: ) parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): - if param.requires_grad: - param.data.copy_(c_param.data) + param.data.copy_(c_param.data) @contextlib.contextmanager def average_parameters( |