From 30a1306f2e5dc6aa91b4b64c2a3acd1bb3b0d7b6 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Nov 2021 15:41:41 -0500
Subject: handle non-trainable parameters

---
 torch_ema/__init__.py |  4 +++-
 torch_ema/ema.py      | 30 ++++++++++++++++++++++--------
 2 files changed, 25 insertions(+), 9 deletions(-)

(limited to 'torch_ema')

diff --git a/torch_ema/__init__.py b/torch_ema/__init__.py
index 9732013..6cf180f 100644
--- a/torch_ema/__init__.py
+++ b/torch_ema/__init__.py
@@ -1 +1,3 @@
-from .ema import *
+from .ema import ExponentialMovingAverage
+
+__all__ = ["ExponentialMovingAverage"]
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 1d03fd6..8077f80 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -18,7 +18,23 @@ class ExponentialMovingAverage:
     Args:
         parameters: Iterable of `torch.nn.Parameter` (typically from
             `model.parameters()`).
+            Note that EMA is computed on *all* provided parameters,
+            regardless of whether or not they have `requires_grad = True`;
+            this allows a single EMA object to be consistently used even
+            if the set of trainable parameters changes from step to step.
+
+            If you want to exclude some parameters from the EMA, do not pass
+            them to the object in the first place. For example:
+
+                ExponentialMovingAverage(
+                    parameters=[p for p in model.parameters() if p.requires_grad],
+                    decay=0.9
+                )
+
+            will ignore parameters that do not require grad.
+
         decay: The exponential decay.
+
         use_num_updates: Whether to use number of updates when computing
             averages.
     """
@@ -33,8 +49,10 @@ class ExponentialMovingAverage:
         self.decay = decay
         self.num_updates = 0 if use_num_updates else None
         parameters = list(parameters)
-        self.shadow_params = [p.clone().detach()
-                              for p in parameters if p.requires_grad]
+        self.shadow_params = [
+            p.clone().detach()
+            for p in parameters
+        ]
         self.collected_params = None
         # By maintaining only a weakref to each parameter,
         # we maintain the old GC behaviour of ExponentialMovingAverage:
@@ -95,7 +113,6 @@ class ExponentialMovingAverage:
             )
         one_minus_decay = 1.0 - decay
         with torch.no_grad():
-            parameters = [p for p in parameters if p.requires_grad]
             for s_param, param in zip(self.shadow_params, parameters):
                 tmp = (s_param - param)
                 # tmp will be a new tensor so we can do in-place
@@ -117,8 +134,7 @@ class ExponentialMovingAverage:
         """
         parameters = self._get_parameters(parameters)
         for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                param.data.copy_(s_param.data)
+            param.data.copy_(s_param.data)
 
     def store(
         self,
@@ -136,7 +152,6 @@ class ExponentialMovingAverage:
         self.collected_params = [
            param.clone()
            for param in parameters
-           if param.requires_grad
        ]
 
     def restore(
@@ -163,8 +178,7 @@ class ExponentialMovingAverage:
             )
         parameters = self._get_parameters(parameters)
         for c_param, param in zip(self.collected_params, parameters):
-            if param.requires_grad:
-                param.data.copy_(c_param.data)
+            param.data.copy_(c_param.data)
 
     @contextlib.contextmanager
     def average_parameters(
-- 
cgit v1.2.3
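
The sketch below illustrates the behaviour this patch introduces: a single
ExponentialMovingAverage tracks *all* parameters of a partially frozen model,
so the shadow parameters stay aligned with the model even if requires_grad
flags change during training. The model, optimizer, and training loop are
hypothetical and not part of the commit; only the ExponentialMovingAverage
constructor, update(), and average_parameters() calls follow the interface
visible in the diff above.

    import torch
    from torch_ema import ExponentialMovingAverage

    # Hypothetical two-layer model with one frozen layer.
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    for p in model[0].parameters():
        p.requires_grad = False  # frozen, but still tracked by the EMA below

    # With this patch, shadow parameters are created for *all* provided
    # parameters, trainable or not, so the zip over shadow_params stays
    # aligned even if requires_grad flags change later.
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad], lr=1e-2
    )

    for _ in range(10):
        x = torch.randn(16, 4)
        loss = model(x).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update()  # update the shadow copy of every tracked parameter

    # Temporarily evaluate with the averaged weights, then restore.
    with ema.average_parameters():
        with torch.no_grad():
            predictions = model(torch.randn(3, 4))

To leave frozen parameters out of the average entirely, filter at construction
time as the new docstring example shows, passing only the parameters with
requires_grad set.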