aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
authorSamuel Fadel <samuelfadel@gmail.com>2021-07-02 13:23:07 +0200
committerGitHub <noreply@github.com>2021-07-02 13:23:07 +0200
commit3985995e523aa25dd3cff7e7984130eef90a4282 (patch)
treeda848e06c121f731542b969bbe6d576dd5304369 /torch_ema/ema.py
parent72ac3d3333c3dc1d95eacdedbdb5a0132958973a (diff)
parent98758f465aa319c0880cc948f34d1b59e8dd4550 (diff)
Merge pull request #8 from Linux-cpp-lisp/context_manager
Context manager & README updates
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 3bcb465..1d03fd6 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -166,6 +166,37 @@ class ExponentialMovingAverage:
if param.requires_grad:
param.data.copy_(c_param.data)
+ @contextlib.contextmanager
+ def average_parameters(
+ self,
+ parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ):
+ r"""
+ Context manager for validation/inference with averaged parameters.
+
+ Equivalent to:
+
+ ema.store()
+ ema.copy_to()
+ try:
+ ...
+ finally:
+ ema.restore()
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.store(parameters)
+ self.copy_to(parameters)
+ try:
+ yield
+ finally:
+ self.restore(parameters)
+
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.