From 984e7d5fbc8eb916a24086a7053335516c14ff34 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 12:49:49 -0600 Subject: Cast state dict to current device/dtype --- torch_ema/ema.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) (limited to 'torch_ema/ema.py') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 6c0415f..f6d8f6e 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -178,15 +178,42 @@ class ExponentialMovingAverage: self.num_updates = state_dict["num_updates"] assert self.num_updates is None or isinstance(self.num_updates, int), \ "Invalid num_updates" + + # Consistant with torch.optim.Optimizer, cast things to current + # device and dtype + if len(self.shadow_params) > 0: + device = self.shadow_params[0].device + dtype = self.shadow_params[0].dtype + else: + device = None + dtype = None + self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), \ "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" + # Cast shadow params: + if device is not None: + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + self.collected_params = state_dict["collected_params"] assert isinstance(self.collected_params, list), \ "collected_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.collected_params ), "collected_params must all be Tensors" + # Cast collected params: + if device is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] -- cgit v1.2.3 From bf6d797c31b35b846c072618c2c8631feeb6db38 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 13:22:26 -0600 Subject: Casting and error checking --- torch_ema/ema.py | 84 +++++++++++++++++++++++++++++++++----------------------- 1 file changed, 49 insertions(+), 35 deletions(-) (limited to 'torch_ema/ema.py') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index f6d8f6e..b3487cf 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -34,7 +34,7 @@ class ExponentialMovingAverage: parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] - self.collected_params = [] + self.collected_params = None # By maintaining only a weakref to each parameter, # we maintain the old GC behaviour of ExponentialMovingAverage: # if the model goes out of scope but the ExponentialMovingAverage @@ -59,6 +59,13 @@ class ExponentialMovingAverage: ) return parameters else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) return parameters def update( @@ -99,7 +106,7 @@ class ExponentialMovingAverage: parameters: Optional[Iterable[torch.nn.Parameter]] = None ) -> None: """ - Copy current parameters into given collection of parameters. + Copy current averaged parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be @@ -125,9 +132,11 @@ class ExponentialMovingAverage: `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) - self.collected_params = [param.clone() - for param in parameters - if param.requires_grad] + self.collected_params = [ + param.clone() + for param in parameters + if param.requires_grad + ] def restore( self, @@ -146,6 +155,11 @@ class ExponentialMovingAverage: parameters with which this `ExponentialMovingAverage` was initialized will be used. """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) parameters = self._get_parameters(parameters) for c_param, param in zip(self.collected_params, parameters): if param.requires_grad: @@ -179,41 +193,41 @@ class ExponentialMovingAverage: assert self.num_updates is None or isinstance(self.num_updates, int), \ "Invalid num_updates" - # Consistant with torch.optim.Optimizer, cast things to current - # device and dtype - if len(self.shadow_params) > 0: - device = self.shadow_params[0].device - dtype = self.shadow_params[0].dtype - else: - device = None - dtype = None - self.shadow_params = state_dict["shadow_params"] assert isinstance(self.shadow_params, list), \ "shadow_params must be a list" assert all( isinstance(p, torch.Tensor) for p in self.shadow_params ), "shadow_params must all be Tensors" - # Cast shadow params: - if device is not None: - self.shadow_params = [ - p.to(device=device, dtype=dtype) - if p.is_floating_point() - else p.to(device=device) - for p in self.shadow_params - ] self.collected_params = state_dict["collected_params"] - assert isinstance(self.collected_params, list), \ - "collected_params must be a list" - assert all( - isinstance(p, torch.Tensor) for p in self.collected_params - ), "collected_params must all be Tensors" - # Cast collected params: - if device is not None: - self.collected_params = [ - p.to(device=device, dtype=dtype) - if p.is_floating_point() - else p.to(device=device) - for p in self.collected_params - ] + if self.collected_params is not None: + assert isinstance(self.collected_params, list), \ + "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len(self.shadow_params), \ + "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." + ) -- cgit v1.2.3 From 81120309acff37307b2226fbda12277ca1662f93 Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 13:42:20 -0600 Subject: Add .to() --- tests/test_ema.py | 17 +++++++++++++++++ torch_ema/ema.py | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+) (limited to 'torch_ema/ema.py') diff --git a/tests/test_ema.py b/tests/test_ema.py index aa43b14..edcea4c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -134,3 +134,20 @@ def test_explicit_params(): ema.update(model2.parameters()) ema.copy_to() assert not torch.all(model.weight == 0.0) + + +def test_to(): + m = torch.nn.Linear(11, 3) + ema = ExponentialMovingAverage(m.parameters(), decay=0.9) + assert ema.shadow_params[0].dtype == torch.get_default_dtype() + ema.to(dtype=torch.float16) + assert ema.shadow_params[0].dtype == torch.float16 + ema.store() + # we store whatever we get + assert ema.collected_params[0].dtype == torch.get_default_dtype() + m = m.to(torch.float16) + ema.store(m.parameters()) + assert ema.collected_params[0].dtype == torch.float16 + ema.to(dtype=torch.float64) + assert ema.collected_params[0].dtype == torch.float64 + assert ema.shadow_params[0].dtype == torch.float64 diff --git a/torch_ema/ema.py b/torch_ema/ema.py index b3487cf..2aa3004 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -165,6 +165,28 @@ class ExponentialMovingAverage: if param.requires_grad: param.data.copy_(c_param.data) + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + for p in self.collected_params + ] + return + def state_dict(self) -> dict: r"""Returns the state of the ExponentialMovingAverage as a dict.""" # Following PyTorch conventions, references to tensors are returned: -- cgit v1.2.3 From e668ae1e0a757cf8217e926be9ae228676fbe17b Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Wed, 21 Apr 2021 14:04:10 -0600 Subject: Indents --- torch_ema/ema.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) (limited to 'torch_ema/ema.py') diff --git a/torch_ema/ema.py b/torch_ema/ema.py index 2aa3004..3bcb465 100644 --- a/torch_ema/ema.py +++ b/torch_ema/ema.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals from typing import Iterable, Optional import weakref import copy +import contextlib import torch @@ -79,10 +80,10 @@ class ExponentialMovingAverage: the `optimizer.step()` call. Args: - parameters: Iterable of `torch.nn.Parameter`; usually the same set of - parameters used to initialize this object. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ parameters = self._get_parameters(parameters) decay = self.decay @@ -109,10 +110,10 @@ class ExponentialMovingAverage: Copy current averaged parameters into given collection of parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ parameters = self._get_parameters(parameters) for s_param, param in zip(self.shadow_params, parameters): @@ -127,9 +128,9 @@ class ExponentialMovingAverage: Save the current parameters for restoring later. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - temporarily stored. If `None`, the parameters of with which this - `ExponentialMovingAverage` was initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. """ parameters = self._get_parameters(parameters) self.collected_params = [ @@ -150,10 +151,10 @@ class ExponentialMovingAverage: restore the former parameters. Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored parameters. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. """ if self.collected_params is None: raise RuntimeError( -- cgit v1.2.3