From 984e7d5fbc8eb916a24086a7053335516c14ff34 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 21 Apr 2021 12:49:49 -0600
Subject: Cast state dict to current device/dtype

---
 torch_ema/ema.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 6c0415f..f6d8f6e 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -178,15 +178,42 @@ class ExponentialMovingAverage:
         self.num_updates = state_dict["num_updates"]
         assert self.num_updates is None or isinstance(self.num_updates, int), \
             "Invalid num_updates"
+
+        # Consistent with torch.optim.Optimizer, cast things to current
+        # device and dtype
+        if len(self.shadow_params) > 0:
+            device = self.shadow_params[0].device
+            dtype = self.shadow_params[0].dtype
+        else:
+            device = None
+            dtype = None
+
         self.shadow_params = state_dict["shadow_params"]
         assert isinstance(self.shadow_params, list), \
             "shadow_params must be a list"
         assert all(
             isinstance(p, torch.Tensor) for p in self.shadow_params
         ), "shadow_params must all be Tensors"
+        # Cast shadow params:
+        if device is not None:
+            self.shadow_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.shadow_params
+            ]
+
         self.collected_params = state_dict["collected_params"]
         assert isinstance(self.collected_params, list), \
             "collected_params must be a list"
         assert all(
             isinstance(p, torch.Tensor) for p in self.collected_params
         ), "collected_params must all be Tensors"
+        # Cast collected params:
+        if device is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
--
cgit v1.2.3

From bf6d797c31b35b846c072618c2c8631feeb6db38 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 21 Apr 2021 13:22:26 -0600
Subject: Casting and error checking

---
 torch_ema/ema.py | 84 +++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 49 insertions(+), 35 deletions(-)

diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index f6d8f6e..b3487cf 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -34,7 +34,7 @@ class ExponentialMovingAverage:
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach()
                               for p in parameters if p.requires_grad]
-        self.collected_params = []
+        self.collected_params = None
         # By maintaining only a weakref to each parameter,
         # we maintain the old GC behaviour of ExponentialMovingAverage:
         # if the model goes out of scope but the ExponentialMovingAverage
@@ -59,6 +59,13 @@ class ExponentialMovingAverage:
                 )
             return parameters
         else:
+            parameters = list(parameters)
+            if len(parameters) != len(self.shadow_params):
+                raise ValueError(
+                    "Number of parameters passed as argument is different "
+                    "from number of shadow parameters maintained by this "
+                    "ExponentialMovingAverage"
+                )
             return parameters
 
     def update(
@@ -99,7 +106,7 @@ class ExponentialMovingAverage:
         parameters: Optional[Iterable[torch.nn.Parameter]] = None
     ) -> None:
         """
-        Copy current parameters into given collection of parameters.
+        Copy current averaged parameters into given collection of parameters.
 
         Args:
           parameters: Iterable of `torch.nn.Parameter`; the parameters to be
@@ -125,9 +132,11 @@ class ExponentialMovingAverage:
            `ExponentialMovingAverage` was initialized will be used.
""" parameters = self._get_parameters(parameters) - self.collected_params = [param.clone() - for param in parameters - if param.requires_grad] + self.collected_params = [ + param.clone() + for param in parameters + if param.requires_grad + ] def restore( self, @@ -146,6 +155,11 @@ class ExponentialMovingAverage: parameters with which this `ExponentialMovingAverage` was initialized will be used. """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: @@ -179,41 +193,41 @@ class ExponentialMovingAverage: assert self.num_updates is None or isinstance(self.num_updates, int), \ "Invalid num_updates" - # Consistant with torch.optim.Optimizer, cast things to current - # device and dtype - if len(self.shadow_params) > 0: - device = self.shadow_params[0].device - dtype = self.shadow_params[0].dtype - else: - device = None - dtype = None - self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), \ "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" - # Cast shadow params: - if device is not None: - self.shadow_params = [ - p.to(device=device, dtype=dtype) - if p.is_floating_point() - else p.to(device=device) - for p in self.shadow_params - ] self.collected_params = state_dict["collected_params"] - assert isinstance(self.collected_params, list), \ - "collected_params must be a list" - assert all( - isinstance(p, torch.Tensor) for p in self.collected_params - ), "collected_params must all be Tensors" - # Cast collected params: - if device is not None: - self.collected_params = [ - p.to(device=device, dtype=dtype) - if p.is_floating_point() - else p.to(device=device) - for p in self.collected_params - ] + if self.collected_params is not None: + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len(self.shadow_params), \ + "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." 
+            )
--
cgit v1.2.3

From 81a99ed1ec6f576d6b8004c7000ca0bc023e7483 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 21 Apr 2021 13:35:20 -0600
Subject: More state_dict tests

---
 tests/test_ema.py        | 42 ------------------------
 tests/test_state_dict.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 85 insertions(+), 42 deletions(-)
 create mode 100644 tests/test_state_dict.py

diff --git a/tests/test_ema.py b/tests/test_ema.py
index 67a14dc..aa43b14 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -1,7 +1,5 @@
 import pytest
 
-import copy
-
 import torch
 
 from torch_ema import ExponentialMovingAverage
@@ -136,43 +134,3 @@ def test_explicit_params():
     ema.update(model2.parameters())
     ema.copy_to()
     assert not torch.all(model.weight == 0.0)
-
-
-@pytest.mark.parametrize("decay", [0.995])
-@pytest.mark.parametrize("use_num_updates", [True, False])
-@pytest.mark.parametrize("explicit_params", [True, False])
-def test_state_dict(decay, use_num_updates, explicit_params):
-    model = torch.nn.Linear(10, 2, bias=False)
-    with torch.no_grad():
-        model.weight.fill_(0.0)
-    ema = ExponentialMovingAverage(
-        model.parameters(),
-        decay=decay,
-        use_num_updates=False
-    )
-    state_dict = copy.deepcopy(ema.state_dict())
-
-    model2 = torch.nn.Linear(10, 2, bias=False)
-    ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
-    ema2.load_state_dict(state_dict)
-    assert ema2.decay == decay
-    assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
-
-    with torch.no_grad():
-        model2.weight.fill_(1.0)
-    if explicit_params:
-        ema2.update(model2.parameters())
-    else:
-        ema2.update()
-    assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
-
-    ema.load_state_dict(ema2.state_dict())
-
-    if explicit_params:
-        ema.copy_to(model.parameters())
-    else:
-        ema.copy_to()
-    assert torch.allclose(
-        model.weight,
-        torch.full(size=(1,), fill_value=(1.0 - decay))
-    ), "average was wrong"
diff --git a/tests/test_state_dict.py b/tests/test_state_dict.py
new file mode 100644
index 0000000..814f446
--- /dev/null
+++ b/tests/test_state_dict.py
@@ -0,0 +1,85 @@
+import pytest
+
+import copy
+
+import torch
+
+from torch_ema import ExponentialMovingAverage
+
+
+@pytest.mark.parametrize("decay", [0.995])
+@pytest.mark.parametrize("use_num_updates", [True, False])
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_state_dict(decay, use_num_updates, explicit_params):
+    model = torch.nn.Linear(10, 2, bias=False)
+    with torch.no_grad():
+        model.weight.fill_(0.0)
+    ema = ExponentialMovingAverage(
+        model.parameters(),
+        decay=decay,
+        use_num_updates=False
+    )
+    state_dict = copy.deepcopy(ema.state_dict())
+
+    model2 = torch.nn.Linear(10, 2, bias=False)
+    ema2 = ExponentialMovingAverage(model2.parameters(), decay=0.0)
+    ema2.load_state_dict(state_dict)
+    assert ema2.decay == decay
+    assert torch.allclose(ema2.shadow_params[0], ema.shadow_params[0])
+
+    with torch.no_grad():
+        model2.weight.fill_(1.0)
+    if explicit_params:
+        ema2.update(model2.parameters())
+    else:
+        ema2.update()
+    assert torch.all(model2.weight == 1.0), "ema.update changed model weights"
+
+    ema.load_state_dict(ema2.state_dict())
+
+    if explicit_params:
+        ema.copy_to(model.parameters())
+    else:
+        ema.copy_to()
+    assert torch.allclose(
+        model.weight,
+        torch.full(size=(1,), fill_value=(1.0 - decay))
+    ), "average was wrong"
+
+
+def test_state_dict_types():
+    m1 = torch.nn.Linear(10, 2, bias=False)
+    m2 = torch.nn.Linear(10, 2, bias=False)
+    m2.to(torch.float16)
+    ema1 = ExponentialMovingAverage(m1.parameters(), decay=0.9)
+    ema2 = ExponentialMovingAverage(m2.parameters(), decay=0.9)
+    ema1.update()
+    ema2.update()
+    ema2.load_state_dict(ema1.state_dict())
+    ema1.copy_to()
+    ema2.copy_to()
+    assert m1.weight.dtype == torch.get_default_dtype()
+    assert m2.weight.dtype == torch.float16
+    assert torch.allclose(m1.weight.to(torch.float16), m2.weight)
+
+
+def test_bad_state_dict1():
+    m = torch.nn.Linear(10, 2, bias=False)
+    ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+    sd = ema.state_dict()
+    sd["shadow_params"][0] = torch.zeros(3, 7)
+    # it doesn't raise at loading, since it can't know shapes.
+    ema.load_state_dict(sd)
+    with pytest.raises(RuntimeError):
+        ema.copy_to()
+    # make sure it didn't change
+    assert torch.any(m.weight.abs() > 0)
+
+
+def test_bad_state_dict2():
+    m = torch.nn.Linear(10, 2, bias=False)
+    ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+    sd = ema.state_dict()
+    sd["shadow_params"] = sd["shadow_params"][:-1]
+    with pytest.raises(ValueError):
+        ema.load_state_dict(sd)
--
cgit v1.2.3

From 81120309acff37307b2226fbda12277ca1662f93 Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 21 Apr 2021 13:42:20 -0600
Subject: Add .to()

---
 tests/test_ema.py | 17 +++++++++++++++++
 torch_ema/ema.py  | 22 ++++++++++++++++++++++
 2 files changed, 39 insertions(+)

diff --git a/tests/test_ema.py b/tests/test_ema.py
index aa43b14..edcea4c 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -134,3 +134,20 @@ def test_explicit_params():
     ema.update(model2.parameters())
     ema.copy_to()
     assert not torch.all(model.weight == 0.0)
+
+
+def test_to():
+    m = torch.nn.Linear(11, 3)
+    ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
+    assert ema.shadow_params[0].dtype == torch.get_default_dtype()
+    ema.to(dtype=torch.float16)
+    assert ema.shadow_params[0].dtype == torch.float16
+    ema.store()
+    # we store whatever we get
+    assert ema.collected_params[0].dtype == torch.get_default_dtype()
+    m = m.to(torch.float16)
+    ema.store(m.parameters())
+    assert ema.collected_params[0].dtype == torch.float16
+    ema.to(dtype=torch.float64)
+    assert ema.collected_params[0].dtype == torch.float64
+    assert ema.shadow_params[0].dtype == torch.float64
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index b3487cf..2aa3004 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -165,6 +165,28 @@ class ExponentialMovingAverage:
             if param.requires_grad:
                 param.data.copy_(c_param.data)
 
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype)
+            if p.is_floating_point()
+            else p.to(device=device)
+            for p in self.shadow_params
+        ]
+        if self.collected_params is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
+        return
+
     def state_dict(self) -> dict:
         r"""Returns the state of the ExponentialMovingAverage as a dict."""
         # Following PyTorch conventions, references to tensors are returned:
--
cgit v1.2.3

From e668ae1e0a757cf8217e926be9ae228676fbe17b Mon Sep 17 00:00:00 2001
From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 21 Apr 2021 14:04:10 -0600
Subject: Indents

---
 torch_ema/ema.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/torch_ema/ema.py b/torch_ema/ema.py
index 2aa3004..3bcb465 100644
--- a/torch_ema/ema.py
+++ b/torch_ema/ema.py
@@ -4,6 +4,7 @@ from __future__ import unicode_literals
 from typing import Iterable, Optional
 import weakref
 import copy
+import contextlib
 
 import torch
 
@@ -79,10 +80,10 @@ class ExponentialMovingAverage:
         the `optimizer.step()` call.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
-            parameters used to initialize this object. If `None`, the
-            parameters with which this `ExponentialMovingAverage` was
-            initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+                parameters used to initialize this object. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
         """
         parameters = self._get_parameters(parameters)
         decay = self.decay
@@ -109,10 +110,10 @@ class ExponentialMovingAverage:
         Copy current averaged parameters into given collection of parameters.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            updated with the stored moving averages. If `None`, the
-            parameters with which this `ExponentialMovingAverage` was
-            initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
         """
         parameters = self._get_parameters(parameters)
         for s_param, param in zip(self.shadow_params, parameters):
@@ -127,9 +128,9 @@ class ExponentialMovingAverage:
         Save the current parameters for restoring later.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            temporarily stored. If `None`, the parameters of with which this
-            `ExponentialMovingAverage` was initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored. If `None`, the parameters with which this
+                `ExponentialMovingAverage` was initialized will be used.
         """
         parameters = self._get_parameters(parameters)
         self.collected_params = [
@@ -150,10 +151,10 @@ class ExponentialMovingAverage:
         restore the former parameters.
 
         Args:
-          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-            updated with the stored parameters. If `None`, the
-            parameters with which this `ExponentialMovingAverage` was
-            initialized will be used.
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
         """
         if self.collected_params is None:
             raise RuntimeError(
--
cgit v1.2.3
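
Taken together, the series gives ExponentialMovingAverage a `.to()` method, makes `collected_params` default to `None` (so `restore()` before `store()` raises), and makes `load_state_dict()` check parameter counts and cast loaded tensors to the tracked parameters' device and dtype. What follows is a minimal usage sketch of the patched API, not part of the patches themselves; the model, shapes, decay, and step size are arbitrary placeholders.

import copy

import torch

from torch_ema import ExponentialMovingAverage

# A throwaway model and an EMA tracker over its parameters.
model = torch.nn.Linear(10, 2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

# One fake training step, then fold the new weights into the average.
model(torch.randn(4, 10)).sum().backward()
with torch.no_grad():
    for p in model.parameters():
        p -= 0.01 * p.grad
ema.update()

# Evaluate with the averaged weights, then put the training weights back.
# restore() now raises RuntimeError if store() was never called.
ema.store()
ema.copy_to()
# ... run validation here ...
ema.restore()

# The new .to() moves/casts the EMA buffers alongside the model.
model.to(torch.float64)
ema.to(dtype=torch.float64)

# load_state_dict() now checks the parameter count and casts the loaded
# tensors to the device/dtype of the parameters the new object tracks.
state = copy.deepcopy(ema.state_dict())
ema2 = ExponentialMovingAverage(model.parameters(), decay=0.995)
ema2.load_state_dict(state)
assert ema2.shadow_params[0].dtype == torch.float64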