diff options
Diffstat (limited to 'torch_ema')
-rw-r--r-- | torch_ema/ema.py | 42 |
1 files changed, 41 insertions, 1 deletions
diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 2e8eb6f..6c0415f 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from typing import Iterable, Optional import weakref +import copy import torch @@ -128,7 +129,6 @@ class ExponentialMovingAverage: for param in parameters if param.requires_grad] - def restore( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None @@ -150,3 +150,43 @@ class ExponentialMovingAverage: for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: param.data.copy_(c_param.data) + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance(self.num_updates, int), \ + "Invalid num_updates" + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), \ + "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + self.collected_params = state_dict["collected_params"] + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" |