aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Fadel <samuelfadel@gmail.com>2021-03-03 14:33:20 +0100
committerSamuel Fadel <samuelfadel@gmail.com>2021-03-03 14:33:20 +0100
commit3950a7b5c4b88f46fd14f620277bad21898597a9 (patch)
treef9269ca38c382cea8b83b0ad5f9d5df3a9c21aad
parent30fb07f4d277fe70cd7596c9be98faf3c30f52fc (diff)
Version bump and only store params requiring grad.v0.2
-rw-r--r--setup.py2
-rw-r--r--torch_ema/ema.py4
2 files changed, 4 insertions, 2 deletions
diff --git a/setup.py b/setup.py
index 8d6fa9e..d998024 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
from setuptools import setup, find_packages
-__version__ = '0.1'
+__version__ = '0.2'
url = 'https://github.com/fadel/pytorch_ema'
download_url = '{}/archive/{}.tar.gz'.format(url, __version__)
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 447fd1e..7771ef7 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -67,7 +67,9 @@ class ExponentialMovingAverage:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
- self.collected_params = [param.clone() for param in parameters]
+ self.collected_params = [param.clone()
+ for param in parameters
+ if param.requires_grad]
def restore(self, parameters):
"""