aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
authorAlby M <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-04-21 14:31:58 -0600
committerAlby M <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-04-21 14:31:58 -0600
commit5bf836aba916fa368bed96b56bc5a5048e8b99bd (patch)
tree178e9c9058353578d8d8f5d99507b2998f01e4d8 /torch_ema/ema.py
parente668ae1e0a757cf8217e926be9ae228676fbe17b (diff)
context manager for validation
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 3bcb465..1d03fd6 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -166,6 +166,37 @@ class ExponentialMovingAverage:
if param.requires_grad:
param.data.copy_(c_param.data)
+ @contextlib.contextmanager
+ def average_parameters(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ):
+ r"""
+ Context manager for validation/inference with averaged parameters.
+
+ Equivalent to:
+
+ ema.store()
+ ema.copy_to()
+ try:
+ ...
+ finally:
+ ema.restore()
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.store(parameters)
+ self.copy_to(parameters)
+ try:
+ yield
+ finally:
+ self.restore(parameters)
+
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.