diff options
author | Samuel Fadel <samuelfadel@gmail.com> | 2021-07-02 13:22:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-02 13:22:55 +0200 |
commit | 72ac3d3333c3dc1d95eacdedbdb5a0132958973a (patch) | |
tree | 402dd92d2b1cf67aa7909c4eb87339b1a0acfc4c | |
parent | a8223abaad1da1293f350d80b636a8d67b2d58a5 (diff) | |
parent | e668ae1e0a757cf8217e926be9ae228676fbe17b (diff) |
Merge pull request #7 from Linux-cpp-lisp/device_and_dtype
Device and dtype improvements
-rw-r--r-- | tests/test_ema.py | 55 | ||||
-rw-r--r-- | tests/test_state_dict.py | 85 | ||||
-rw-r--r-- | torch_ema/ema.py | 114 |
3 files changed, 189 insertions, 65 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index 67a14dc..edcea4c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -1,7 +1,5 @@ import pytest -import copy - import torch from torch_ema import ExponentialMovingAverage @@ -138,41 +136,18 @@ def test_explicit_params(): assert not torch.all(model.weight == 0.0) -@pytest.mark.parametrize("decay", [0.995]) -@pytest.mark.parametrize("use_num_updates", [True, False]) -@pytest.mark.parametrize("explicit_params", [True, False]) -def test_state_dict(decay, use_num_updates, explicit_params): - model = torch.nn.Linear(10, 2, bias=False) - with torch.no_grad(): - model.weight.fill_(0.0) - ema = ExponentialMovingAverage( - model.parameters(), - decay=decay, - use_num_updates=False - ) - state_dict = copy.deepcopy(ema.state_dict()) - - model2 = torch.nn.Linear(10, 2, bias=False) - ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0) - ema2.load_state_dict(state_dict) - assert ema2.decay == decay - assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0]) - - with torch.no_grad(): - model2.weight.fill_(1.0) - if explicit_params: - ema2.update(model2.parameters()) - else: - ema2.update() - assert torch.all(model2.weight == 1.0), "ema.update changed model weights" - - ema.load_state_dict(ema2.state_dict()) - - if explicit_params: - ema.copy_to(model.parameters()) - else: - ema.copy_to() - assert torch.allclose( - model.weight, - torch.full(size=(1,), fill_value=(1.0 - decay)) - ), "average was wrong" +def test_to(): + m = torch.nn.Linear(11, 3) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + assert ema.shadow_params[0].dtype == torch.get_default_dtype() + ema.to(dtype=torch.float16) + assert ema.shadow_params[0].dtype == torch.float16 + ema.store() + # we store whatever we get + assert ema.collected_params[0].dtype == torch.get_default_dtype() + m = m.to(torch.float16) + ema.store(m.parameters()) + assert ema.collected_params[0].dtype == torch.float16 + ema.to(dtype=torch.float64) + assert ema.collected_params[0].dtype == torch.float64 + assert ema.shadow_params[0].dtype == torch.float64 diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py new file mode 100644 index 0000000..814f446 --- /dev/null +++ b/tests/test_state_dict.py @@ -0,0 +1,85 @@ +import pytest + +import copy + +import torch + +from torch_ema import ExponentialMovingAverage + + +@pytest.mark.parametrize("decay", [0.995]) +@pytest.mark.parametrize("use_num_updates", [True, False]) +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_state_dict(decay, use_num_updates, explicit_params): + model = torch.nn.Linear(10, 2, bias=False) + with torch.no_grad(): + model.weight.fill_(0.0) + ema = ExponentialMovingAverage( + model.parameters(), + decay=decay, + use_num_updates=False + ) + state_dict = copy.deepcopy(ema.state_dict()) + + model2 = torch.nn.Linear(10, 2, bias=False) + ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0) + ema2.load_state_dict(state_dict) + assert ema2.decay == decay + assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0]) + + with torch.no_grad(): + model2.weight.fill_(1.0) + if explicit_params: + ema2.update(model2.parameters()) + else: + ema2.update() + assert torch.all(model2.weight == 1.0), "ema.update changed model weights" + + ema.load_state_dict(ema2.state_dict()) + + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() + assert torch.allclose( + model.weight, + torch.full(size=(1,), fill_value=(1.0 - decay)) + ), "average was wrong" + + +def test_state_dict_types(): + m1 = torch.nn.Linear(10, 2, bias=False) + m2 = torch.nn.Linear(10, 2, bias=False) + m2.to(torch.float16) + ema1 = ExponentialMovingAverage(m1.parameters(), decay=0.9) + ema2 = ExponentialMovingAverage(m2.parameters(), decay=0.9) + ema1.update() + ema2.update() + ema2.load_state_dict(ema1.state_dict()) + ema1.copy_to() + ema2.copy_to() + assert m1.weight.dtype == torch.get_default_dtype() + assert m2.weight.dtype == torch.float16 + assert torch.allclose(m1.weight.to(torch.float16), m2.weight) + + +def test_bad_state_dict1(): + m = torch.nn.Linear(10, 2, bias=False) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + sd = ema.state_dict() + sd["shadow_params"][0] = torch.zeros(3, 7) + # it doesn't raise at loading, since it can't know shapes. + ema.load_state_dict(sd) + with pytest.raises(RuntimeError): + ema.copy_to() + # make sure it didn't change + assert torch.any(m.weight.abs() > 0) + + +def test_bad_state_dict2(): + m = torch.nn.Linear(10, 2, bias=False) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + sd = ema.state_dict() + sd["shadow_params"] = sd["shadow_params"][:-1] + with pytest.raises(ValueError): + ema.load_state_dict(sd) diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 6c0415f..3bcb465 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from typing import Iterable, Optional import weakref import copy +import contextlib import torch @@ -34,7 +35,7 @@ class ExponentialMovingAverage: parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] - self.collected_params = [] + self.collected_params = None # By maintaining only a weakref to each parameter, # we maintain the old GC behaviour of ExponentialMovingAverage: # if the model goes out of scope but the ExponentialMovingAverage @@ -59,6 +60,13 @@ class ExponentialMovingAverage: ) return parameters else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) return parameters def update( @@ -72,10 +80,10 @@ class ExponentialMovingAverage: the `optimizer.step()` call. Args: - parameters: Iterable of `torch.nn.Parameter`; usually the same set of - parameters used to initialize this object. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ parameters = self._get_parameters(parameters) decay = self.decay @@ -99,13 +107,13 @@ class ExponentialMovingAverage: parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ - Copy current parameters into given collection of parameters. + Copy current averaged parameters into given collection of parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): @@ -120,14 +128,16 @@ class ExponentialMovingAverage: Save the current parameters for restoring later. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. If `None`, the parameters of with which this - `ExponentialMovingAverage` was initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) - self.collected_params = [param.clone() - for param in parameters - if param.requires_grad] + self.collected_params = [ + param.clone() + for param in parameters + if param.requires_grad + ] def restore( self, @@ -141,16 +151,43 @@ class ExponentialMovingAverage: restore the former parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: param.data.copy_(c_param.data) + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] + return + def state_dict(self) -> dict: r"""Returns the state of the ExponentialMovingAverage as a dict.""" # Following PyTorch conventions, references to tensors are returned: @@ -178,15 +215,42 @@ class ExponentialMovingAverage: self.num_updates = state_dict["num_updates"] assert self.num_updates is None or isinstance(self.num_updates, int), \ "Invalid num_updates" + self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), \ "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" + self.collected_params = state_dict["collected_params"] - assert isinstance(self.collected_params, list), \ - "collected_params must be a list" - assert all( - isinstance(p, torch.Tensor) for p in self.collected_params - ), "collected_params must all be Tensors" + if self.collected_params is not None: + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len(self.shadow_params), \ + "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." + ) |