diff options
author | Zehui Lin <zehui-lin_t@outlook.com> | 2021-02-08 10:09:36 +0800 |
---|---|---|
committer | Zehui Lin <zehui-lin_t@outlook.com> | 2021-02-08 10:09:36 +0800 |
commit | 359c155723f0b08f6e6c1784c7951e33d92b306d (patch) | |
tree | 39f6a13c973df60c1192e11eb363488955446388 /torch_ema | |
parent | 0ed9f9ddc9b93bbcd6d3828803f20a23cc500f5e (diff) |
minor
Diffstat (limited to 'torch_ema')
-rw-r--r-- | torch_ema/ema.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 8f39bc5..50f9aea 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -20,6 +20,7 @@ class ExponentialMovingAverage: """ if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') + self.collected_parameters = [] self.decay = decay self.num_updates = 0 if use_num_updates else None self.shadow_params = [p.clone().detach() @@ -54,9 +55,7 @@ class ExponentialMovingAverage: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. """ - self.collected_parameters = [] for s_param, param in zip(self.shadow_params, parameters): - self.collected_parameters.append(param.clone()) if param.requires_grad: param.data.copy_(s_param.data) |