diff options
author | Samuel Fadel <samuelfadel@gmail.com> | 2021-07-02 13:23:07 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-07-02 13:23:07 +0200 |
commit | 3985995e523aa25dd3cff7e7984130eef90a4282 (patch) | |
tree | da848e06c121f731542b969bbe6d576dd5304369 /tests | |
parent | 72ac3d3333c3dc1d95eacdedbdb5a0132958973a (diff) | |
parent | 98758f465aa319c0880cc948f34d1b59e8dd4550 (diff) |
Merge pull request #8 from Linux-cpp-lisp/context_manager
Context manager & README updates
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_ema.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py index edcea4c..fa90a8c 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -71,6 +71,57 @@ def test_val_error(decay, use_num_updates, explicit_params): "Restored model wasn't the same as stored model" +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_contextmanager(explicit_params): + """Confirm that EMA validation error is lower than raw validation error.""" + torch.manual_seed(0) + x_train = torch.rand((100, 10)) + y_train = torch.rand(100).round().long() + x_val = torch.rand((100, 10)) + y_val = torch.rand(100).round().long() + model = torch.nn.Linear(10, 2) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) + ema = ExponentialMovingAverage( + model.parameters(), + decay=0.99, + ) + + # Train for a few epochs + model.train() + for _ in range(20): + logits = model(x_train) + loss = torch.nn.functional.cross_entropy(logits, y_train) + optimizer.zero_grad() + loss.backward() + optimizer.step() + if explicit_params: + ema.update(model.parameters()) + else: + ema.update() + + final_weight = model.weight.clone().detach() + + # Validation: original + model.eval() + logits = model(x_val) + loss_orig = torch.nn.functional.cross_entropy(logits, y_val) + print(f"Original loss: {loss_orig}") + + # Validation: with EMA + if explicit_params: + cm = ema.average_parameters(model.parameters()) + else: + cm = ema.average_parameters() + + with cm: + logits = model(x_val) + loss_ema = torch.nn.functional.cross_entropy(logits, y_val) + + print(f"EMA loss: {loss_ema}") + assert loss_ema < loss_orig, "EMA loss wasn't lower" + assert torch.all(model.weight == final_weight), "Restore failed" + + @pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0]) @pytest.mark.parametrize("use_num_updates", [True, False]) @pytest.mark.parametrize("explicit_params", [True, False]) |