author     Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>  2021-04-20 12:03:13 -0600
committer  GitHub <noreply@github.com>  2021-04-20 20:03:13 +0200
commit     1a47d1c99784d1247f41308d335b2087de13334c (patch)
tree       78803311cbf29ac47cac1a107e121a20d0af73a7 /torch_ema
parent     3950a7b5c4b88f46fd14f620277bad21898597a9 (diff)
Tests and miscellaneous (#4)
* linter
* Add tests
* Add restore test
* Type hints
* PyTorch dependency
* One fewer temp buffer
Diffstat (limited to 'torch_ema')
-rw-r--r--  torch_ema/ema.py  50
1 file changed, 31 insertions, 19 deletions
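The "One fewer temp buffer" item refers to the rewritten shadow-parameter update in `update` below: instead of materializing `one_minus_decay * (s_param - param)` through two temporary tensors, the patch scales the single difference tensor in place. A minimal sketch (not part of the commit; standalone tensors stand in for real model parameters) checking that both forms compute the same EMA step `s <- s - (1 - decay) * (s - p)`:

import torch

s_param = torch.randn(5)          # shadow (EMA) parameter
param = torch.randn(5)            # current model parameter
one_minus_decay = 1.0 - 0.9999

# Old form: `one_minus_decay * (s_param - param)` allocates two temporaries.
expected = s_param - one_minus_decay * (s_param - param)

# New form: the subtraction's result is already a fresh tensor, so it can
# be scaled in place and reused, saving one allocation per parameter.
tmp = s_param - param
tmp.mul_(one_minus_decay)
result = s_param.clone().sub_(tmp)

assert torch.allclose(expected, result)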
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 7771ef7..0233c78 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -1,23 +1,30 @@
 from __future__ import division
 from __future__ import unicode_literals
 
+from typing import Iterable
+
 import torch
 
-# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+# Partially based on:
+# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
 
 class ExponentialMovingAverage:
     """
     Maintains (exponential) moving average of a set of parameters.
-    """
-    def __init__(self, parameters, decay, use_num_updates=True):
-        """
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; usually the result of
+
+    Args:
+        parameters: Iterable of `torch.nn.Parameter`; usually the result of
             `model.parameters()`.
-            decay: The exponential decay.
-            use_num_updates: Whether to use number of updates when computing
+        decay: The exponential decay.
+        use_num_updates: Whether to use number of updates when computing
             averages.
-        """
+    """
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter],
+        decay: float,
+        use_num_updates: bool = True
+    ):
         if decay < 0.0 or decay > 1.0:
             raise ValueError('Decay must be between 0 and 1')
         self.decay = decay
@@ -26,7 +33,7 @@ class ExponentialMovingAverage:
                               for p in parameters if p.requires_grad]
         self.collected_params = []
 
-    def update(self, parameters):
+    def update(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         Update currently maintained parameters.
@@ -40,18 +47,24 @@
         decay = self.decay
         if self.num_updates is not None:
             self.num_updates += 1
-            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+            decay = min(
+                decay,
+                (1 + self.num_updates) / (10 + self.num_updates)
+            )
         one_minus_decay = 1.0 - decay
         with torch.no_grad():
             parameters = [p for p in parameters if p.requires_grad]
             for s_param, param in zip(self.shadow_params, parameters):
-                s_param.sub_(one_minus_decay * (s_param - param))
+                tmp = (s_param - param)
+                # tmp will be a new tensor so we can do in-place
+                tmp.mul_(one_minus_decay)
+                s_param.sub_(tmp)
 
-    def copy_to(self, parameters):
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """
         Copy current parameters into given collection of parameters.
 
-        Args:
+        Args:
             parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                 updated with the stored moving averages.
         """
@@ -59,11 +72,11 @@ class ExponentialMovingAverage:
             if param.requires_grad:
                 param.data.copy_(s_param.data)
 
-    def store(self, parameters):
+    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Save the current parameters for restoring later.
 
-        Args:
+        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
@@ -71,7 +84,7 @@
                                  for param in parameters
                                  if param.requires_grad]
 
-    def restore(self, parameters):
+    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
@@ -79,11 +92,10 @@
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
 
-        Args:
+        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            if param.requires_grad:
                param.data.copy_(c_param.data)
-
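For reference, a hedged usage sketch of the API as it stands after this patch. The `model`/`optimizer` setup is hypothetical, and the top-level import assumes the package re-exports `ExponentialMovingAverage` from `torch_ema/ema.py`:

import torch
from torch_ema import ExponentialMovingAverage  # assumed re-export of ema.py

model = torch.nn.Linear(10, 2)                  # hypothetical model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
# With use_num_updates=True (the default), early updates use the smaller
# decay min(0.995, (1 + n) / (10 + n)), ramping up as n grows.
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

for _ in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema.update(model.parameters())    # fold current weights into the average

# Validate with the averaged weights, then put the training weights back.
ema.store(model.parameters())         # stash current (training) parameters
ema.copy_to(model.parameters())       # load EMA weights into the model
# ... run validation here ...
ema.restore(model.parameters())       # restore training parameters

The store/copy_to/restore cycle is the pattern the `restore` docstring above describes for validating with EMA weights without disturbing the ongoing optimization.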