Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--  torch_ema/ema.py  30
1 file changed, 22 insertions, 8 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 1d03fd6..8077f80 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -18,7 +18,23 @@ class ExponentialMovingAverage:
Args:
parameters: Iterable of `torch.nn.Parameter` (typically from
`model.parameters()`).
+ Note that EMA is computed on *all* provided parameters,
+ regardless of whether or not they have `requires_grad = True`;
+ this allows a single EMA object to be consistently used even
+ if the set of trainable parameters changes from step to step.
+
+ If you want to exclude some parameters from the EMA, do not pass them
+ to the object in the first place. For example:
+
+ ExponentialMovingAverage(
+ parameters=[p for p in model.parameters() if p.requires_grad],
+ decay=0.9
+ )
+
+ will ignore parameters that do not require grad.
+
decay: The exponential decay.
+
use_num_updates: Whether to use number of updates when computing
averages.
"""
@@ -33,8 +49,10 @@ class ExponentialMovingAverage:
self.decay = decay
self.num_updates = 0 if use_num_updates else None
parameters = list(parameters)
- self.shadow_params = [p.clone().detach()
- for p in parameters if p.requires_grad]
+ self.shadow_params = [
+ p.clone().detach()
+ for p in parameters
+ ]
self.collected_params = None
# By maintaining only a weakref to each parameter,
# we maintain the old GC behaviour of ExponentialMovingAverage:
@@ -95,7 +113,6 @@ class ExponentialMovingAverage:
)
one_minus_decay = 1.0 - decay
with torch.no_grad():
- parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
tmp = (s_param - param)
# tmp will be a new tensor so we can do in-place
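For reference, the update rule this loop implements is `shadow <- shadow - (1 - decay) * (shadow - param)`, which is the same as `decay * shadow + (1 - decay) * param`. The in-place `mul_`/`sub_` steps that follow the comment are not shown in this hunk, so the standalone numeric sketch below reproduces them as an assumption:

    import torch

    decay = 0.9
    one_minus_decay = 1.0 - decay

    s_param = torch.tensor([1.0])  # shadow (EMA) value
    param = torch.tensor([0.0])    # current parameter value

    with torch.no_grad():
        tmp = s_param - param      # fresh tensor, safe for in-place ops
        tmp.mul_(one_minus_decay)  # (1 - decay) * (shadow - param)
        s_param.sub_(tmp)          # shadow <- shadow - (1 - decay) * (shadow - param)

    # Closed form: decay * shadow + (1 - decay) * param = 0.9 * 1.0 + 0.1 * 0.0
    assert torch.allclose(s_param, torch.tensor([0.9]))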
@@ -117,8 +134,7 @@ class ExponentialMovingAverage:
"""
parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
- if param.requires_grad:
- param.data.copy_(s_param.data)
+ param.data.copy_(s_param.data)
def store(
self,
@@ -136,7 +152,6 @@ class ExponentialMovingAverage:
self.collected_params = [
param.clone()
for param in parameters
- if param.requires_grad
]
def restore(
@@ -163,8 +178,7 @@ class ExponentialMovingAverage:
)
parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
- if param.requires_grad:
- param.data.copy_(c_param.data)
+ param.data.copy_(c_param.data)
@contextlib.contextmanager
def average_parameters(
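Taken together, the hunks above mean that `update()`, `copy_to()`, `store()`, and `restore()` now all operate on every parameter passed in, trainable or not. A minimal training-loop sketch, continuing the hypothetical `model` and `ema_all` from the earlier example:

    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad], lr=1e-2
    )

    for _ in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Pass the same parameter collection used at construction so the
        # shadow parameters and the live parameters stay zipped in order.
        ema_all.update(model.parameters())

    # Evaluate with the averaged weights, then put the originals back.
    ema_all.store(model.parameters())
    ema_all.copy_to(model.parameters())
    # ... run validation here ...
    ema_all.restore(model.parameters())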