diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_ema.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index fa90a8c..4bc1901 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -187,6 +187,27 @@ def test_explicit_params(): assert not torch.all(model.weight == 0.0) +def test_some_untrainable(): + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.x = torch.nn.Parameter(torch.randn(3)) + self.y = torch.nn.Parameter(torch.randn(3)) + self.y.requires_grad_(False) + + def forward(self, x): + return self.x * x + self.y + + model = Mod() + ema = ExponentialMovingAverage(model.parameters(), decay=0.9) + ema.update() + with torch.no_grad(): + model.x *= 1.1 + ema.update() + ema.store() + ema.copy_to() + + def test_to(): m = torch.nn.Linear(11, 3) ema = ExponentialMovingAverage(m.parameters(), decay=0.9) |