aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
authorSamuel Fadel <samuelfadel@gmail.com>2019-05-26 10:46:38 -0300
committerSamuel Fadel <samuelfadel@gmail.com>2019-05-26 10:46:38 -0300
commit05fe2c74d7b85f62b838f740af0f87b8f4d3691f (patch)
tree4aa2cd5f4d06135c534a41d4afb3e4896c521717 /torch_ema/ema.py
parent09cfcf97e0e938a93867c7d445f1c9b4dcfea023 (diff)
Filter by requires_grad before zip() in update().
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 32ed7ca..4e9dd99 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -42,9 +42,9 @@ class ExponentialMovingAverage:
decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
+ parameters = [p for p in parameters if p.requires_grad]
for s_param, param in zip(self.shadow_params, parameters):
- if param.requires_grad:
- s_param.sub_(one_minus_decay * (s_param - param))
+ s_param.sub_(one_minus_decay * (s_param - param))
def copy_to(self, parameters):
"""