From 30a1306f2e5dc6aa91b4b64c2a3acd1bb3b0d7b6 Mon Sep 17 00:00:00 2001
From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com>
Date: Wed, 17 Nov 2021 15:41:41 -0500
Subject: handle non-trainable parameters

---
 tests/test_ema.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)

(limited to 'tests')

diff --git a/tests/test_ema.py b/tests/test_ema.py
index fa90a8c..4bc1901 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -187,6 +187,27 @@ def test_explicit_params():
     assert not torch.all(model.weight == 0.0)
 
 
+def test_some_untrainable():
+    class Mod(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.x = torch.nn.Parameter(torch.randn(3))
+            self.y = torch.nn.Parameter(torch.randn(3))
+            self.y.requires_grad_(False)
+
+        def forward(self, x):
+            return self.x * x + self.y
+
+    model = Mod()
+    ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
+    ema.update()
+    with torch.no_grad():
+        model.x *= 1.1
+    ema.update()
+    ema.store()
+    ema.copy_to()
+
+
 def test_to():
     m = torch.nn.Linear(11, 3)
     ema = ExponentialMovingAverage(m.parameters(), decay=0.9)
-- 
cgit v1.2.3