-rw-r--r--  .gitignore            |  6
-rw-r--r--  LICENSE               | 19
-rw-r--r--  README.md             | 56
-rw-r--r--  setup.py              | 24
-rw-r--r--  torch_ema/__init__.py |  1
-rw-r--r--  torch_ema/ema.py      | 64
6 files changed, 170 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9e838e9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+__pycache__/
+build/
+dist/
+.cache/
+.eggs/
+*.egg-info/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..dffb620
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,19 @@
+Copyright (c) 2019 Samuel G. Fadel <samuelfadel@gmail.com>
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..aeffbfe
--- /dev/null
+++ b/README.md
@@ -0,0 +1,56 @@
+# pytorch_ema
+
+A very small library for computing exponential moving averages of model
+parameters.
+
+This library was written for personal use. Nevertheless, if you run into
+issues or have suggestions for improvement, feel free to open an issue or a
+pull request.
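+
+Averages follow the standard EMA update rule
+`s = decay * s + (1 - decay) * p` (see `torch_ema/ema.py`), with an optional
+warm-up of the decay over the first updates.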
+
+## Example
+
+```python
+import torch
+import torch.nn.functional as F
+
+from torch_ema import ExponentialMovingAverage
+
+
+x_train = torch.rand((100, 10))
+y_train = torch.rand(100).round().long()
+model = torch.nn.Linear(10, 2)
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
+ema = ExponentialMovingAverage(model.parameters(), decay=0.995)
+
+# Train for a few epochs
+model.train()
+for _ in range(10):
+    logits = model(x_train)
+    loss = F.cross_entropy(logits, y_train)
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    ema.update(model.parameters())
+
+# Compare losses:
+# Original
+model.eval()
+logits = model(x_train)
+loss = F.cross_entropy(logits, y_train)
+print(loss.item())
+
+# With EMA
+ema.copy_to(model.parameters())
+logits = model(x_train)
+loss = F.cross_entropy(logits, y_train)
+print(loss.item())
+```
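+
+Note that `copy_to` overwrites the parameters it is given, so the original
+(non-averaged) weights are lost after the call. A minimal way to keep them
+around, assuming you want to switch back to them later, is to clone them
+first: `original = [p.clone() for p in model.parameters()]`.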
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..8d6fa9e
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,24 @@
+from setuptools import setup, find_packages
+
+__version__ = '0.1'
+url = 'https://github.com/fadel/pytorch_ema'
+download_url = '{}/archive/{}.tar.gz'.format(url, __version__)
+
+install_requires = []
+setup_requires = []
+tests_require = []
+
+setup(
+    name='torch_ema',
+    version=__version__,
+    description='PyTorch library for computing moving averages of model parameters.',
+    author='Samuel G. Fadel',
+    author_email='samuelfadel@gmail.com',
+    url=url,
+    download_url=download_url,
+    keywords=['pytorch', 'parameters', 'deep-learning'],
+    install_requires=install_requires,
+    setup_requires=setup_requires,
+    tests_require=tests_require,
+    packages=find_packages(),
+)
diff --git a/torch_ema/__init__.py b/torch_ema/__init__.py
new file mode 100644
index 0000000..9732013
--- /dev/null
+++ b/torch_ema/__init__.py
@@ -0,0 +1 @@
+from .ema import *
diff --git a/torch_ema/ema.py b/torch_ema/ema.py
new file mode 100644
index 0000000..32ed7ca
--- /dev/null
+++ b/torch_ema/ema.py
@@ -0,0 +1,64 @@
+from __future__ import division
+from __future__ import unicode_literals
+
+import torch
+
+
+# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+class ExponentialMovingAverage:
+ """
+ Maintains (exponential) moving average of a set of parameters.
+ """
+ def __init__(self, parameters, decay, use_num_updates=True):
+ """
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the result of
+ `model.parameters()`.
+ decay: The exponential decay.
+ use_num_updates: Whether to use number of updates when computing
+ averages.
+ """
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+ self.decay = decay
+ self.num_updates = 0 if use_num_updates else None
+ self.shadow_params = [p.clone().detach()
+ for p in parameters if p.requires_grad]
+
+    def update(self, parameters):
+        """
+        Update the currently maintained parameter averages.
+
+        Call this every time the parameters are updated, e.g. after each
+        `optimizer.step()` call.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; usually the same
+                set of parameters used to initialize this object.
+        """
+        decay = self.decay
+        if self.num_updates is not None:
+            self.num_updates += 1
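+            # Warm-up schedule (as in TensorFlow's moving_averages): early
+            # on, use a smaller effective decay so the average tracks the
+            # still-changing parameters more closely.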
+            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+        one_minus_decay = 1.0 - decay
+        with torch.no_grad():
+            for s_param, param in zip(self.shadow_params, parameters):
+                if param.requires_grad:
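+                    # In-place update; algebraically equivalent to
+                    # s_param = decay * s_param + (1 - decay) * param.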
+                    s_param.sub_(one_minus_decay * (s_param - param))
+
+    def copy_to(self, parameters):
+        """
+        Copy the averaged (shadow) parameters into the given parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to
+                be updated with the stored moving averages.
+        """
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                param.data.copy_(s_param.data)