aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 447fd1e..7771ef7 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -67,7 +67,9 @@ class ExponentialMovingAverage:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
- self.collected_params = [param.clone() for param in parameters]
+ self.collected_params = [param.clone()
+ for param in parameters
+ if param.requires_grad]
def restore(self, parameters):
"""