aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSamuel Fadel <samuelfadel@gmail.com>2021-07-02 13:22:55 +0200
committerGitHub <noreply@github.com>2021-07-02 13:22:55 +0200
commit72ac3d3333c3dc1d95eacdedbdb5a0132958973a (patch)
tree402dd92d2b1cf67aa7909c4eb87339b1a0acfc4c
parenta8223abaad1da1293f350d80b636a8d67b2d58a5 (diff)
parente668ae1e0a757cf8217e926be9ae228676fbe17b (diff)
Merge pull request #7 from Linux-cpp-lisp/device_and_dtype
Device and dtype improvements
-rw-r--r--tests/test_ema.py55
-rw-r--r--tests/test_state_dict.py85
-rw-r--r--torch_ema/ema.py114
3 files changed, 189 insertions, 65 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index 67a14dc..edcea4c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -1,7 +1,5 @@
import pytest
-import copy
-
import torch
from torch_ema import ExponentialMovingAverage
@@ -138,41 +136,18 @@ def test_explicit_params():
assert not torch.all(model.weight == 0.0)
-@pytest.mark.parametrize("decay", [0.995])
-@pytest.mark.parametrize("use_num_updates", [True, False])
-@pytest.mark.parametrize("explicit_params", [True, False])
-def test_state_dict(decay, use_num_updates, explicit_params):
- model = torch.nn.Linear(10, 2, bias=False)
- with torch.no_grad():
- model.weight.fill_(0.0)
- ema = ExponentialMovingAverage(
- model.parameters(),
- decay=decay,
- use_num_updates=False
- )
- state_dict = copy.deepcopy(ema.state_dict())
-
- model2 = torch.nn.Linear(10, 2, bias=False)
- ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
- ema2.load_state_dict(state_dict)
- assert ema2.decay == decay
- assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
-
- with torch.no_grad():
- model2.weight.fill_(1.0)
- if explicit_params:
- ema2.update(model2.parameters())
- else:
- ema2.update()
- assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
-
- ema.load_state_dict(ema2.state_dict())
-
- if explicit_params:
- ema.copy_to(model.parameters())
- else:
- ema.copy_to()
- assert torch.allclose(
- model.weight,
- torch.full(size=(1,), fill_value=(1.0 - decay))
- ), "average was wrong"
+def test_to():
+ m = torch.nn.Linear(11, 3)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ assert ema.shadow_params[0].dtype == torch.get_default_dtype()
+ ema.to(dtype=torch.float16)
+ assert ema.shadow_params[0].dtype == torch.float16
+ ema.store()
+ # we store whatever we get
+ assert ema.collected_params[0].dtype == torch.get_default_dtype()
+ m = m.to(torch.float16)
+ ema.store(m.parameters())
+ assert ema.collected_params[0].dtype == torch.float16
+ ema.to(dtype=torch.float64)
+ assert ema.collected_params[0].dtype == torch.float64
+ assert ema.shadow_params[0].dtype == torch.float64
diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py
new file mode 100644
index 0000000..814f446
--- /dev/null
+++ b/tests/test_state_dict.py
@@ -0,0 +1,85 @@
+import pytest
+
+import copy
+
+import torch
+
+from torch_ema import ExponentialMovingAverage
+
+
+@pytest.mark.parametrize("decay", [0.995])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_state_dict(decay, use_num_updates, explicit_params):
+ model = torch.nn.Linear(10, 2, bias=False)
+ with torch.no_grad():
+ model.weight.fill_(0.0)
+ ema = ExponentialMovingAverage(
+ model.parameters(),
+ decay=decay,
+ use_num_updates=False
+ )
+ state_dict = copy.deepcopy(ema.state_dict())
+
+ model2 = torch.nn.Linear(10, 2, bias=False)
+ ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
+ ema2.load_state_dict(state_dict)
+ assert ema2.decay == decay
+ assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
+
+ with torch.no_grad():
+ model2.weight.fill_(1.0)
+ if explicit_params:
+ ema2.update(model2.parameters())
+ else:
+ ema2.update()
+ assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
+
+ ema.load_state_dict(ema2.state_dict())
+
+ if explicit_params:
+ ema.copy_to(model.parameters())
+ else:
+ ema.copy_to()
+ assert torch.allclose(
+ model.weight,
+ torch.full(size=(1,), fill_value=(1.0 - decay))
+ ), "average was wrong"
+
+
+def test_state_dict_types():
+ m1 = torch.nn.Linear(10, 2, bias=False)
+ m2 = torch.nn.Linear(10, 2, bias=False)
+ m2.to(torch.float16)
+ ema1 = ExponentialMovingAverage(m1.parameters(), decay=0.9)
+ ema2 = ExponentialMovingAverage(m2.parameters(), decay=0.9)
+ ema1.update()
+ ema2.update()
+ ema2.load_state_dict(ema1.state_dict())
+ ema1.copy_to()
+ ema2.copy_to()
+ assert m1.weight.dtype == torch.get_default_dtype()
+ assert m2.weight.dtype == torch.float16
+ assert torch.allclose(m1.weight.to(torch.float16), m2.weight)
+
+
+def test_bad_state_dict1():
+ m = torch.nn.Linear(10, 2, bias=False)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ sd = ema.state_dict()
+ sd["shadow_params"][0] = torch.zeros(3, 7)
+ # it doesn't raise at loading, since it can't know shapes.
+ ema.load_state_dict(sd)
+ with pytest.raises(RuntimeError):
+ ema.copy_to()
+ # make sure it didn't change
+ assert torch.any(m.weight.abs() > 0)
+
+
+def test_bad_state_dict2():
+ m = torch.nn.Linear(10, 2, bias=False)
+ ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+ sd = ema.state_dict()
+ sd["shadow_params"] = sd["shadow_params"][:-1]
+ with pytest.raises(ValueError):
+ ema.load_state_dict(sd)
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..3bcb465 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
from typing import Iterable, Optional
import weakref
import copy
+import contextlib
import torch
@@ -34,7 +35,7 @@ class ExponentialMovingAverage:
parameters = list(parameters)
self.shadow_params = [p.clone().detach()
for p in parameters if p.requires_grad]
- self.collected_params = []
+ self.collected_params = None
# By maintaining only a weakref to each parameter,
# we maintain the old GC behaviour of ExponentialMovingAverage:
# if the model goes out of scope but the ExponentialMovingAverage
@@ -59,6 +60,13 @@ class ExponentialMovingAverage:
)
return parameters
else:
+ parameters = list(parameters)
+ if len(parameters) != len(self.shadow_params):
+ raise ValueError(
+ "Number of parameters passed as argument is different "
+ "from number of shadow parameters maintained by this "
+ "ExponentialMovingAverage"
+ )
return parameters
def update(
@@ -72,10 +80,10 @@ class ExponentialMovingAverage:
the `optimizer.step()` call.
Args:
- parameters: Iterable of `torch.nn.Parameter`; usually the same set of
- parameters used to initialize this object. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+ parameters used to initialize this object. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
parameters = self._get_parameters(parameters)
decay = self.decay
@@ -99,13 +107,13 @@ class ExponentialMovingAverage:
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
- Copy current parameters into given collection of parameters.
+ Copy current averaged parameters into given collection of parameters.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored moving averages. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
@@ -120,14 +128,16 @@ class ExponentialMovingAverage:
Save the current parameters for restoring later.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- temporarily stored. If `None`, the parameters of with which this
- `ExponentialMovingAverage` was initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored. If `None`, the parameters of with which this
+ `ExponentialMovingAverage` was initialized will be used.
"""
parameters = self._get_parameters(parameters)
- self.collected_params = [param.clone()
- for param in parameters
- if param.requires_grad]
+ self.collected_params = [
+ param.clone()
+ for param in parameters
+ if param.requires_grad
+ ]
def restore(
self,
@@ -141,16 +151,43 @@ class ExponentialMovingAverage:
restore the former parameters.
Args:
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
- updated with the stored parameters. If `None`, the
- parameters with which this `ExponentialMovingAverage` was
- initialized will be used.
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
"""
+ if self.collected_params is None:
+ raise RuntimeError(
+ "This ExponentialMovingAverage has no `store()`ed weights "
+ "to `restore()`"
+ )
parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
if param.requires_grad:
param.data.copy_(c_param.data)
+ def to(self, device=None, dtype=None) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+ Args:
+ device: like `device` argument to `torch.Tensor.to`
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.shadow_params
+ ]
+ if self.collected_params is not None:
+ self.collected_params = [
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ for p in self.collected_params
+ ]
+ return
+
def state_dict(self) -> dict:
r"""Returns the state of the ExponentialMovingAverage as a dict."""
# Following PyTorch conventions, references to tensors are returned:
@@ -178,15 +215,42 @@ class ExponentialMovingAverage:
self.num_updates = state_dict["num_updates"]
assert self.num_updates is None or isinstance(self.num_updates, int), \
"Invalid num_updates"
+
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), \
"shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
+
self.collected_params = state_dict["collected_params"]
- assert isinstance(self.collected_params, list), \
- "collected_params must be a list"
- assert all(
- isinstance(p, torch.Tensor) for p in self.collected_params
- ), "collected_params must all be Tensors"
+ if self.collected_params is not None:
+ assert isinstance(self.collected_params, list), \
+ "collected_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.collected_params
+ ), "collected_params must all be Tensors"
+ assert len(self.collected_params) == len(self.shadow_params), \
+ "collected_params and shadow_params had different lengths"
+
+ if len(self.shadow_params) == len(self._params_refs):
+ # Consistant with torch.optim.Optimizer, cast things to consistant
+ # device and dtype with the parameters
+ params = [p() for p in self._params_refs]
+ # If parameters have been garbage collected, just load the state
+ # we were given without change.
+ if not any(p is None for p in params):
+ # ^ parameter references are still good
+ for i, p in enumerate(params):
+ self.shadow_params[i] = self.shadow_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ if self.collected_params is not None:
+ self.collected_params[i] = self.collected_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ else:
+ raise ValueError(
+ "Tried to `load_state_dict()` with the wrong number of "
+ "parameters in the saved state."
+ )