From 3950a7b5c4b88f46fd14f620277bad21898597a9 Mon Sep 17 00:00:00 2001 From: Samuel Fadel Date: Wed, 3 Mar 2021 14:33:20 +0100 Subject: Version bump and only store params requiring grad. --- torch_ema/ema.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'torch_ema/ema.py') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 447fd1e..7771ef7 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -67,7 +67,9 @@ class ExponentialMovingAverage: parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ - self.collected_params = [param.clone() for param in parameters] + self.collected_params = [param.clone() + for param in parameters + if param.requires_grad] def restore(self, parameters): """ -- cgit v1.2.3