diff options
author | Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-11-17 15:41:41 -0500 |
---|---|---|
committer | Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> | 2021-11-17 15:41:41 -0500 |
commit | 30a1306f2e5dc6aa91b4b64c2a3acd1bb3b0d7b6 (patch) | |
tree | 1204efe7c93c741136829e80628b1b42ac603e8f /tests | |
parent | 3985995e523aa25dd3cff7e7984130eef90a4282 (diff) |
handle non-trainable parameters
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_ema.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index fa90a8c..4bc1901 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -187,6 +187,27 @@ def test_explicit_params(): assert not torch.all(model.weight == 0.0) +def test_some_untrainable(): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(3)) + self.y = torch.nn.Parameter(torch.randn(3)) + self.y.requires_grad_(False) + + def forward(self, x): + return self.x * x + self.y + + model = Mod() + ema = ExponentialMovingAverage(model.parameters(), decay=0.9) + ema.update() + with torch.no_grad(): + model.x *= 1.1 + ema.update() + ema.store() + ema.copy_to() + + def test_to(): m = torch.nn.Linear(11, 3) ema = ExponentialMovingAverage(m.parameters(), decay=0.9) |