author     Samuel Fadel <samuelfadel@gmail.com>  2021-03-03 14:33:20 +0100
committer  Samuel Fadel <samuelfadel@gmail.com>  2021-03-03 14:33:20 +0100
commit     3950a7b5c4b88f46fd14f620277bad21898597a9 (patch)
tree       f9269ca38c382cea8b83b0ad5f9d5df3a9c21aad
parent     30fb07f4d277fe70cd7596c9be98faf3c30f52fc (diff)
Version bump and only store params requiring grad. (tag: v0.2)
 setup.py         | 2 +-
 torch_ema/ema.py | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/setup.py b/setup.py
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from setuptools import setup, find_packages
 
-__version__ = '0.1'
+__version__ = '0.2'
 
 url = 'https://github.com/fadel/pytorch_ema'
 download_url = '{}/archive/{}.tar.gz'.format(url, __version__)
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 447fd1e..7771ef7 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -67,7 +67,9 @@ class ExponentialMovingAverage:
             parameters: Iterable of `torch.nn.Parameter`; the
                 parameters to be temporarily stored.
         """
-        self.collected_params = [param.clone() for param in parameters]
+        self.collected_params = [param.clone()
+                                 for param in parameters
+                                 if param.requires_grad]
 
     def restore(self, parameters):
         """
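For context, a minimal usage sketch of the `store`/`restore` pair this commit touches. The constructor call, `update`, and `copy_to` follow the project README and are assumptions about the API at this version, not part of this diff; the point illustrated is that, after this change, `store` clones only parameters with `requires_grad=True`, so frozen parameters are skipped.

    import torch
    from torch_ema import ExponentialMovingAverage

    model = torch.nn.Linear(10, 2)
    # Freeze one parameter: after this commit, store() no longer clones it.
    model.bias.requires_grad_(False)

    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    # ... after each optimizer step during training:
    ema.update(model.parameters())

    # Evaluate with the averaged weights, then put the originals back:
    ema.store(model.parameters())    # clones only params with requires_grad=True
    ema.copy_to(model.parameters())
    # ... run validation here ...
    ema.restore(model.parameters())

This also avoids cloning parameters for which no shadow (EMA) copy is maintained in the first place, keeping `store` consistent with the rest of the class, which tracks only parameters that require gradients.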