aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlby M <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-04-21 07:06:20 -0600
committerGitHub <noreply@github.com>2021-04-21 15:06:20 +0200
commita8223abaad1da1293f350d80b636a8d67b2d58a5 (patch)
treee06fa79877aac117cbf685d4862f567a75cc8d6c
parent3015941b5c61b686161701887a8618f5f77044bb (diff)
State dict support (#6)
-rw-r--r--tests/test_ema.py46
-rw-r--r--torch_ema/ema.py42
2 files changed, 86 insertions, 2 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index ad6ee37..67a14dc 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -1,5 +1,7 @@
import pytest
+import copy
+
import torch
from torch_ema import ExponentialMovingAverage
@@ -40,6 +42,7 @@ def test_val_error(decay, use_num_updates, explicit_params):
model.eval()
logits = model(x_val)
loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
+ print(f"Original loss: {loss_orig}")
# Validation: with EMA
# First save original parameters before replacing with EMA version
@@ -55,6 +58,7 @@ def test_val_error(decay, use_num_updates, explicit_params):
logits = model(x_val)
loss_ema = torch.nn.functional.cross_entropy(logits, y_val)
+ print(f"EMA loss: {loss_ema}")
assert loss_ema < loss_orig, "EMA loss wasn't lower"
# Test restore
@@ -131,4 +135,44 @@ def test_explicit_params():
model2.weight.fill_(1.0)
ema.update(model2.parameters())
ema.copy_to()
- assert not torch.all(model.weight == 0.0) \ No newline at end of file
+ assert not torch.all(model.weight == 0.0)
+
+
+@pytest.mark.parametrize("decay", [0.995])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_state_dict(decay, use_num_updates, explicit_params):
+ model = torch.nn.Linear(10, 2, bias=False)
+ with torch.no_grad():
+ model.weight.fill_(0.0)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=decay,
+ use_num_updates=False
+ )
+ state_dict = copy.deepcopy(ema.state_dict())
+
+ model2 = torch.nn.Linear(10, 2, bias=False)
+ ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
+ ema2.load_state_dict(state_dict)
+ assert ema2.decay == decay
+ assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
+
+ with torch.no_grad():
+ model2.weight.fill_(1.0)
+ if explicit_params:
+ ema2.update(model2.parameters())
+ else:
+ ema2.update()
+ assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
+
+ ema.load_state_dict(ema2.state_dict())
+
+ if explicit_params:
+ ema.copy_to(model.parameters())
+ else:
+ ema.copy_to()
+ assert torch.allclose(
+ model.weight,
+ torch.full(size=(1,), fill_value=(1.0 - decay))
+ ), "average was wrong"
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 2e8eb6f..6c0415f 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -3,6 +3,7 @@ from __future__ import unicode_literals
from typing import Iterable, Optional
import weakref
+import copy
import torch
@@ -128,7 +129,6 @@ class ExponentialMovingAverage:
for param in parameters
if param.requires_grad]
-
def restore(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
@@ -150,3 +150,43 @@ class ExponentialMovingAverage:
for c_param, param in zip(self.collected_params, parameters):
if param.requires_grad:
param.data.copy_(c_param.data)
+
+ def state_dict(self) -> dict:
+ r"""Returns the state of the ExponentialMovingAverage as a dict."""
+ # Following PyTorch conventions, references to tensors are returned:
+ # "returns a reference to the state and not its copy!" -
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+ return {
+ "decay": self.decay,
+ "num_updates": self.num_updates,
+ "shadow_params": self.shadow_params,
+ "collected_params": self.collected_params
+ }
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the ExponentialMovingAverage state.
+
+ Args:
+ state_dict (dict): EMA state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = copy.deepcopy(state_dict)
+ self.decay = state_dict["decay"]
+ if self.decay < 0.0 or self.decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+ self.num_updates = state_dict["num_updates"]
+ assert self.num_updates is None or isinstance(self.num_updates, int), \
+ "Invalid num_updates"
+ self.shadow_params = state_dict["shadow_params"]
+ assert isinstance(self.shadow_params, list), \
+ "shadow_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.shadow_params
+ ), "shadow_params must all be Tensors"
+ self.collected_params = state_dict["collected_params"]
+ assert isinstance(self.collected_params, list), \
+ "collected_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.collected_params
+ ), "collected_params must all be Tensors"