aboutsummaryrefslogtreecommitdiff
path: root/torch_ema
diff options
context:
space:
mode:
authorZehui Lin <zehui-lin_t@outlook.com>2021-02-08 10:09:36 +0800
committerZehui Lin <zehui-lin_t@outlook.com>2021-02-08 10:09:36 +0800
commit359c155723f0b08f6e6c1784c7951e33d92b306d (patch)
tree39f6a13c973df60c1192e11eb363488955446388 /torch_ema
parent0ed9f9ddc9b93bbcd6d3828803f20a23cc500f5e (diff)
minor
Diffstat (limited to 'torch_ema')
-rw-r--r--torch_ema/ema.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 8f39bc5..50f9aea 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -20,6 +20,7 @@ class ExponentialMovingAverage:
"""
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
+ self.collected_parameters = []
self.decay = decay
self.num_updates = 0 if use_num_updates else None
self.shadow_params = [p.clone().detach()
@@ -54,9 +55,7 @@ class ExponentialMovingAverage:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages.
"""
- self.collected_parameters = []
for s_param, param in zip(self.shadow_params, parameters):
- self.collected_parameters.append(param.clone())
if param.requires_grad:
param.data.copy_(s_param.data)