-rw-r--r--   .gitignore              6
-rw-r--r--   LICENSE                19
-rw-r--r--   README.md              47
-rw-r--r--   setup.py               24
-rw-r--r--   torch_ema/__init__.py   1
-rw-r--r--   torch_ema/ema.py       59
6 files changed, 156 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9e838e9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+__pycache__/
+build/
+dist/
+.cache/
+.eggs/
+*.egg-info/
diff --git a/LICENSE b/LICENSE
new file mode 100644
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2019 Samuel G. Fadel <samuelfadel@gmail.com>
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..aeffbfe
--- /dev/null
+++ b/README.md
@@ -0,0 +1,47 @@
+# pytorch_ema
+
+A very small library for computing exponential moving averages of model
+parameters.
+
+This library was written for personal use. Nevertheless, if you run into issues
+or have suggestions for improvement, feel free to open either a new issue or
+pull request.
+
+## Example
+
+```python
+import torch
+import torch.nn.functional as F
+
+from torch_ema import ExponentialMovingAverage
+
+
+x_train = torch.rand((100, 10))
+y_train = torch.rand(100).round().long()
+model = torch.nn.Linear(10, 2)
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
+
+# Train for a few epochs
+model.train()
+for _ in range(10):
+    logits = model(x_train)
+    loss = F.cross_entropy(logits, y_train)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    ema.update(model.parameters())
+
+# Compare losses:
+# Original
+model.eval()
+logits = model(x_train)
+loss = F.cross_entropy(logits, y_train)
+print(loss.item())
+
+# With EMA
+ema.copy_to(model.parameters())
+logits = model(x_train)
+loss = F.cross_entropy(logits, y_train)
+print(loss.item())
+```
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..8d6fa9e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,24 @@
+from setuptools import setup, find_packages
+
+__version__ = '0.1'
+url = 'https://github.com/fadel/pytorch_ema'
+download_url = '{}/archive/{}.tar.gz'.format(url, __version__)
+
+install_requires = []
+setup_requires = []
+tests_require = []
+
+setup(
+    name='torch_ema',
+    version=__version__,
+    description='PyTorch library for computing moving averages of model parameters.',
+    author='Samuel G. Fadel',
+    author_email='samuelfadel@gmail.com',
+    url=url,
+    download_url=download_url,
+    keywords=['pytorch', 'parameters', 'deep-learning'],
+    install_requires=install_requires,
+    setup_requires=setup_requires,
+    tests_require=tests_require,
+    packages=find_packages(),
+)
diff --git a/torch_ema/__init__.py b/torch_ema/__init__.py
new file mode 100644
index 0000000..9732013
--- /dev/null
+++ b/torch_ema/__init__.py
@@ -0,0 +1 @@
+from .ema import *
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
new file mode 100644
index 0000000..32ed7ca
--- /dev/null
+++ b/torch_ema/ema.py
@@ -0,0 +1,59 @@
+from __future__ import division
+from __future__ import unicode_literals
+
+import torch
+
+
+# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+class ExponentialMovingAverage:
+    """
+    Maintains (exponential) moving average of a set of parameters.
+    """
+    def __init__(self, parameters, decay, use_num_updates=True):
+        """
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; usually the result of
+                `model.parameters()`.
+            decay: The exponential decay.
+            use_num_updates: Whether to use number of updates when computing
+                averages.
+        """
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError('Decay must be between 0 and 1')
+        self.decay = decay
+        self.num_updates = 0 if use_num_updates else None
+        self.shadow_params = [p.clone().detach()
+                              for p in parameters if p.requires_grad]
+
+    def update(self, parameters):
+        """
+        Update currently maintained parameters.
+
+        Call this every time the parameters are updated, such as the result of
+        the `optimizer.step()` call.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+                parameters used to initialize this object.
+        """
+        decay = self.decay
+        if self.num_updates is not None:
+            self.num_updates += 1
+            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+        one_minus_decay = 1.0 - decay
+        with torch.no_grad():
+            for s_param, param in zip(self.shadow_params, parameters):
+                if param.requires_grad:
+                    s_param.sub_(one_minus_decay * (s_param - param))
+
+    def copy_to(self, parameters):
+        """
+        Copies current parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages.
+        """
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                param.data.copy_(s_param.data)
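Note (not part of the commit): the in-place update in `update()` is algebraically the textbook EMA rule `s = decay * s + (1 - decay) * p`, and with `use_num_updates=True` the effective decay warms up from small values toward the configured `decay`. A minimal sketch checking both, with illustrative tensors `s` and `p`:

```python
import torch

decay = 0.995
s = torch.tensor([1.0, 2.0])  # shadow (averaged) parameter, illustrative
p = torch.tensor([0.5, 0.0])  # current model parameter, illustrative

# In-place form used by update(): s -= (1 - decay) * (s - p)
in_place = s - (1 - decay) * (s - p)
# Textbook form of the same rule
textbook = decay * s + (1 - decay) * p
assert torch.allclose(in_place, textbook)

# Effective decay schedule when use_num_updates=True:
for n in (1, 10, 100, 10000):
    print(n, min(decay, (1 + n) / (10 + n)))
# 1 -> 0.1818..., 10 -> 0.55, 100 -> 0.9182..., 10000 -> capped at 0.995
```

Early on, the schedule weights recent parameters heavily (the average is nearly useless anyway), and it only approaches the configured decay once enough updates have accumulated.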
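Note (not part of the commit): `copy_to()` overwrites the live weights, and this version of the library has no built-in way to restore them afterwards, so training cannot simply resume after an EMA evaluation. A sketch of one manual backup/restore pattern using only plain PyTorch (the names `backup` and `model` are illustrative, not library API):

```python
import torch

from torch_ema import ExponentialMovingAverage

model = torch.nn.Linear(10, 2)
ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
# ... train and call ema.update(model.parameters()) as in the README ...

# Back up the live weights before overwriting them with the averages.
backup = [p.detach().clone() for p in model.parameters()]
ema.copy_to(model.parameters())
# ... run evaluation with the averaged weights here ...

# Restore the live weights so optimization can continue where it left off.
with torch.no_grad():
    for p, b in zip(model.parameters(), backup):
        p.copy_(b)
```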