aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
authorSamuel Fadel <samuelfadel@gmail.com>2021-03-03 14:33:20 +0100
committerSamuel Fadel <samuelfadel@gmail.com>2021-03-03 14:33:20 +0100
commit3950a7b5c4b88f46fd14f620277bad21898597a9 (patch)
treef9269ca38c382cea8b83b0ad5f9d5df3a9c21aad /torch_ema/ema.py
parent30fb07f4d277fe70cd7596c9be98faf3c30f52fc (diff)
Version bump and only store params requiring grad.v0.2
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 447fd1e..7771ef7 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -67,7 +67,9 @@ class ExponentialMovingAverage:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
- self.collected_params = [param.clone() for param in parameters]
+ self.collected_params = [param.clone()
+ for param in parameters
+ if param.requires_grad]
def restore(self, parameters):
"""