aboutsummaryrefslogtreecommitdiff
path: root/tests/test_ema.py
diff options
context:
space:
mode:
authorLinux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-11-17 15:41:41 -0500
committerLinux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>2021-11-17 15:41:41 -0500
commit30a1306f2e5dc6aa91b4b64c2a3acd1bb3b0d7b6 (patch)
tree1204efe7c93c741136829e80628b1b42ac603e8f /tests/test_ema.py
parent3985995e523aa25dd3cff7e7984130eef90a4282 (diff)
handle non-trainable parameters
Diffstat (limited to 'tests/test_ema.py')
-rw-r--r--tests/test_ema.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index fa90a8c..4bc1901 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -187,6 +187,27 @@ def test_explicit_params():
assert not torch.all(model.weight == 0.0)
+def test_some_untrainable():
+ class Mod(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.x = torch.nn.Parameter(torch.randn(3))
+ self.y = torch.nn.Parameter(torch.randn(3))
+ self.y.requires_grad_(False)
+
+ def forward(self, x):
+ return self.x * x + self.y
+
+ model = Mod()
+ ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
+ ema.update()
+ with torch.no_grad():
+ model.x *= 1.1
+ ema.update()
+ ema.store()
+ ema.copy_to()
+
+
def test_to():
m = torch.nn.Linear(11, 3)
ema = ExponentialMovingAverage(m.parameters(), decay=0.9)