aboutsummaryrefslogtreecommitdiff
path: root/torch_ema/ema.py
diff options
context:
space:
mode:
Diffstat (limited to 'torch_ema/ema.py')
-rw-r--r--torch_ema/ema.py31
1 files changed, 16 insertions, 15 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 2aa3004..3bcb465 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from typing import Iterable, Optional
import weakref
import copy
+import contextlib
import torch
@@ -79,10 +80,10 @@ class ExponentialMovingAverage:
the `optimizer.step()` call.
Args:
- parameters: Iterable of `torch.nn.Parameter`; usually the same set of
- parameters used to initialize this object. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+ parameters used to initialize this object. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
parameters = self._get_parameters(parameters)
decay = self.decay
@@ -109,10 +110,10 @@ class ExponentialMovingAverage:
Copy current averaged parameters into given collection of parameters.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored moving averages. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
@@ -127,9 +128,9 @@ class ExponentialMovingAverage:
Save the current parameters for restoring later.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored. If `None`, the parameters of with which this
- `ExponentialMovingAverage` was initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored. If `None`, the parameters of with which this
+ `ExponentialMovingAverage` was initialized will be used.
"""
parameters = self._get_parameters(parameters)
self.collected_params = [
@@ -150,10 +151,10 @@ class ExponentialMovingAverage:
restore the former parameters.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
if self.collected_params is None:
raise RuntimeError(