aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_ema.py21
1 files changed, 21 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index fa90a8c..4bc1901 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -187,6 +187,27 @@ def test_explicit_params():
assert not torch.all(model.weight == 0.0)
+def test_some_untrainable():
+ class Mod(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.x = torch.nn.Parameter(torch.randn(3))
+ self.y = torch.nn.Parameter(torch.randn(3))
+ self.y.requires_grad_(False)
+
+ def forward(self, x):
+ return self.x * x + self.y
+
+ model = Mod()
+ ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
+ ema.update()
+ with torch.no_grad():
+ model.x *= 1.1
+ ema.update()
+ ema.store()
+ ema.copy_to()
+
+
def test_to():
m = torch.nn.Linear(11, 3)
ema = ExponentialMovingAverage(m.parameters(), decay=0.9)