From a8223abaad1da1293f350d80b636a8d67b2d58a5 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 07:06:20 -0600 Subject: State dict support (#6) --- tests/test_ema.py | 46 +++++++++++++++++++++++++++++++++++++++++++++- torch_ema/ema.py | 42 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 2 deletions(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index ad6ee37..67a14dc 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -1,5 +1,7 @@ import pytest +import copy + import torch from torch_ema import ExponentialMovingAverage @@ -40,6 +42,7 @@ def test_val_error(decay, use_num_updates, explicit_params): model.eval() logits = model(x_val) loss_orig = torch.nn.functional.cross_entropy(logits, y_val) + print(f"Original loss: {loss_orig}") # Validation: with EMA # First save original parameters before replacing with EMA version @@ -55,6 +58,7 @@ def test_val_error(decay, use_num_updates, explicit_params): logits = model(x_val) loss_ema = torch.nn.functional.cross_entropy(logits, y_val) + print(f"EMA loss: {loss_ema}") assert loss_ema < loss_orig, "EMA loss wasn't lower" # Test restore @@ -131,4 +135,44 @@ def test_explicit_params(): model2.weight.fill_(1.0) ema.update(model2.parameters()) ema.copy_to() - assert not torch.all(model.weight == 0.0) \ No newline at end of file + assert not torch.all(model.weight == 0.0) + + +@pytest.mark.parametrize("decay", [0.995]) +@pytest.mark.parametrize("use_num_updates", [True, False]) +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_state_dict(decay, use_num_updates, explicit_params): + model = torch.nn.Linear(10, 2, bias=False) + with torch.no_grad(): + model.weight.fill_(0.0) + ema = ExponentialMovingAverage( + model.parameters(), + decay=decay, + use_num_updates=False + ) + state_dict = copy.deepcopy(ema.state_dict()) + + model2 = torch.nn.Linear(10, 2, bias=False) + ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0) + ema2.load_state_dict(state_dict) + assert ema2.decay == decay + assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0]) + + with torch.no_grad(): + model2.weight.fill_(1.0) + if explicit_params: + ema2.update(model2.parameters()) + else: + ema2.update() + assert torch.all(model2.weight == 1.0), "ema.update changed model weights" + + ema.load_state_dict(ema2.state_dict()) + + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() + assert torch.allclose( + model.weight, + torch.full(size=(1,), fill_value=(1.0 - decay)) + ), "average was wrong" diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 2e8eb6f..6c0415f 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from typing import Iterable, Optional import weakref +import copy import torch @@ -128,7 +129,6 @@ class ExponentialMovingAverage: for param in parameters if param.requires_grad] - def restore( self, parameters: Optional[Iterable[torch.nn.Parameter]] = None @@ -150,3 +150,43 @@ class ExponentialMovingAverage: for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: param.data.copy_(c_param.data) + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance(self.num_updates, int), \ + "Invalid num_updates" + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), \ + "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + self.collected_params = state_dict["collected_params"] + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" -- cgit v1.2.3