diff options
author | Samuel Fadel <samuelfadel@gmail.com> | 2021-07-02 13:23:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-02 13:23:07 +0200 |
commit | 3985995e523aa25dd3cff7e7984130eef90a4282 (patch) | |
tree | da848e06c121f731542b969bbe6d576dd5304369 /torch_ema | |
parent | 72ac3d3333c3dc1d95eacdedbdb5a0132958973a (diff) | |
parent | 98758f465aa319c0880cc948f34d1b59e8dd4550 (diff) |
Merge pull request #8 from Linux-cpp-lisp/context_manager
Context manager & README updates
Diffstat (limited to 'torch_ema')
-rw-r--r-- | torch_ema/ema.py | 31 |
1 file changed, 31 insertions, 0 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 3bcb465..1d03fd6 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -166,6 +166,37 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(c_param.data) + @contextlib.contextmanager + def average_parameters( + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + self.restore(parameters) + def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. |