From 05fe2c74d7b85f62b838f740af0f87b8f4d3691f Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Sun, 26 May 2019 10:46:38 -0300 Subject: Filter by requires_grad before zip() in update(). --- torch_ema/ema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'torch_ema') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 32ed7ca..4e9dd99 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -42,9 +42,9 @@ class ExponentialMovingAverage: decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay with torch.no_grad(): + parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) + s_param.sub_(one_minus_decay * (s_param - param)) def copy_to(self, parameters): """ -- cgit v1.2.3