author | Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-11-17 15:41:41 -0500
---|---|---
committer | Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-11-17 15:41:41 -0500
commit | 30a1306f2e5dc6aa91b4b64c2a3acd1bb3b0d7b6 (patch) |
tree | 1204efe7c93c741136829e80628b1b42ac603e8f /torch_ema/ema.py |
parent | 3985995e523aa25dd3cff7e7984130eef90a4282 (diff) |
handle non-trainable parameters
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r-- | torch_ema/ema.py | 30
1 file changed, 22 insertions, 8 deletions
```diff
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 1d03fd6..8077f80 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -18,7 +18,23 @@ class ExponentialMovingAverage:
     Args:
         parameters: Iterable of `torch.nn.Parameter` (typically from
             `model.parameters()`).
+            Note that EMA is computed on *all* provided parameters,
+            regardless of whether or not they have `requires_grad = True`;
+            this allows a single EMA object to be consistently used even
+            if which parameters are trainable changes from step to step.
+
+            If you do not want some parameters included in the EMA, do not
+            pass them to the object in the first place. For example:
+
+                ExponentialMovingAverage(
+                    parameters=[p for p in model.parameters() if p.requires_grad],
+                    decay=0.9
+                )
+
+            will ignore parameters that do not require grad.
+
         decay: The exponential decay.
+
         use_num_updates: Whether to use number of updates when computing
             averages.
     """
@@ -33,8 +49,10 @@ class ExponentialMovingAverage:
         self.decay = decay
         self.num_updates = 0 if use_num_updates else None
         parameters = list(parameters)
-        self.shadow_params = [p.clone().detach()
-                              for p in parameters if p.requires_grad]
+        self.shadow_params = [
+            p.clone().detach()
+            for p in parameters
+        ]
         self.collected_params = None
         # By maintaining only a weakref to each parameter,
         # we maintain the old GC behaviour of ExponentialMovingAverage:
@@ -95,7 +113,6 @@ class ExponentialMovingAverage:
             )
         one_minus_decay = 1.0 - decay
         with torch.no_grad():
-            parameters = [p for p in parameters if p.requires_grad]
             for s_param, param in zip(self.shadow_params, parameters):
                 tmp = (s_param - param)
                 # tmp will be a new tensor so we can do in-place
@@ -117,8 +134,7 @@ class ExponentialMovingAverage:
         """
         parameters = self._get_parameters(parameters)
         for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                param.data.copy_(s_param.data)
+            param.data.copy_(s_param.data)
 
     def store(
         self,
@@ -136,7 +152,6 @@ class ExponentialMovingAverage:
         self.collected_params = [
             param.clone()
             for param in parameters
-            if param.requires_grad
         ]
 
     def restore(
@@ -163,8 +178,7 @@ class ExponentialMovingAverage:
             )
         parameters = self._get_parameters(parameters)
        for c_param, param in zip(self.collected_params, parameters):
-            if param.requires_grad:
-                param.data.copy_(c_param.data)
+            param.data.copy_(c_param.data)
 
     @contextlib.contextmanager
     def average_parameters(
```
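For context, here is a minimal usage sketch (not part of the commit) of the behaviour this change enables: the EMA now shadows every parameter it is given, trainable or not, and excluding frozen parameters is done by filtering at construction time, as the new docstring suggests. The toy model, optimizer, and hyperparameters below are illustrative assumptions, and the sketch assumes the package's usual `from torch_ema import ExponentialMovingAverage` entry point and the argument-free `ema.update()` form backed by the weak-reference bookkeeping mentioned in this file.

```python
import torch
from torch_ema import ExponentialMovingAverage

# Hypothetical toy model with one frozen (non-trainable) parameter.
model = torch.nn.Linear(4, 2)
model.bias.requires_grad_(False)

# After this change, the EMA shadows *all* parameters passed in,
# regardless of their requires_grad flag.
ema = ExponentialMovingAverage(model.parameters(), decay=0.9)

# To exclude frozen parameters, filter before construction instead:
ema_trainable_only = ExponentialMovingAverage(
    [p for p in model.parameters() if p.requires_grad],
    decay=0.9,
)

optimizer = torch.optim.SGD(
    [p for p in model.parameters() if p.requires_grad], lr=0.1
)
for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.update()  # updates shadows for trainable and frozen parameters alike

# Temporarily evaluate with the averaged weights.
with ema.average_parameters():
    model(torch.randn(1, 4))
```

The design point this illustrates is that the shadow parameters stay in one-to-one correspondence with whatever was passed at construction, so toggling `requires_grad` mid-training (e.g. staged fine-tuning) no longer desynchronizes the EMA from the parameter list.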