diff options
-rw-r--r-- | torch_ema/ema.py | 31 |
1 files changed, 16 insertions, 15 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 2aa3004..3bcb465 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from typing import Iterable, Optional import weakref import copy +import contextlib import torch @@ -79,10 +80,10 @@ class ExponentialMovingAverage: the `optimizer.step()` call. Args: - parameters: Iterable of `torch.nn.Parameter`; usually the same set of - parameters used to initialize this object. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ parameters = self._get_parameters(parameters) decay = self.decay @@ -109,10 +110,10 @@ class ExponentialMovingAverage: Copy current averaged parameters into given collection of parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): @@ -127,9 +128,9 @@ class ExponentialMovingAverage: Save the current parameters for restoring later. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. If `None`, the parameters of with which this - `ExponentialMovingAverage` was initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) self.collected_params = [ @@ -150,10 +151,10 @@ class ExponentialMovingAverage: restore the former parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ if self.collected_params is None: raise RuntimeError( |