From 3015941b5c61b686161701887a8618f5f77044bb Mon Sep 17 00:00:00 2001 From: Alby M <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 20 Apr 2021 12:16:39 -0600 Subject: Option to keep parameters reference in `ExponentialMovingAverage` (#5) --- tests/test_ema.py | 64 ++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 51 insertions(+), 13 deletions(-) (limited to 'tests') diff --git a/tests/test_ema.py b/tests/test_ema.py index 6d7e43e..ad6ee37 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -7,7 +7,8 @@ from torch_ema import ExponentialMovingAverage @pytest.mark.parametrize("decay", [0.995, 0.9]) @pytest.mark.parametrize("use_num_updates", [True, False]) -def test_val_error(decay, use_num_updates): +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_val_error(decay, use_num_updates, explicit_params): """Confirm that EMA validation error is lower than raw validation error.""" torch.manual_seed(0) x_train = torch.rand((100, 10)) @@ -30,27 +31,37 @@ def test_val_error(decay, use_num_updates): optimizer.zero_grad() loss.backward() optimizer.step() - ema.update(model.parameters()) + if explicit_params: + ema.update(model.parameters()) + else: + ema.update() # Validation: original model.eval() logits = model(x_val) loss_orig = torch.nn.functional.cross_entropy(logits, y_val) - print(f"Original loss: {loss_orig}") # Validation: with EMA # First save original parameters before replacing with EMA version - ema.store(model.parameters()) + if explicit_params: + ema.store(model.parameters()) + else: + ema.store() # Copy EMA parameters to model - ema.copy_to(model.parameters()) + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() logits = model(x_val) loss_ema = torch.nn.functional.cross_entropy(logits, y_val) - print(f"EMA loss: {loss_ema}") assert loss_ema < loss_orig, "EMA loss wasn't lower" # Test restore - ema.restore(model.parameters()) + if explicit_params: + ema.restore(model.parameters()) + else: + ema.restore() model.eval() logits = model(x_val) loss_orig2 = torch.nn.functional.cross_entropy(logits, y_val) @@ -60,7 +71,8 @@ def test_val_error(decay, use_num_updates): @pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0]) @pytest.mark.parametrize("use_num_updates", [True, False]) -def test_store_restore(decay, use_num_updates): +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_store_restore(decay, use_num_updates, explicit_params): model = torch.nn.Linear(10, 2) ema = ExponentialMovingAverage( model.parameters(), @@ -68,15 +80,22 @@ def test_store_restore(decay, use_num_updates): use_num_updates=use_num_updates ) orig_weight = model.weight.clone().detach() - ema.store(model.parameters()) + if explicit_params: + ema.store(model.parameters()) + else: + ema.store() with torch.no_grad(): model.weight.uniform_(0.0, 1.0) - ema.restore(model.parameters()) + if explicit_params: + ema.restore(model.parameters()) + else: + ema.restore() assert torch.all(model.weight == orig_weight) @pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0]) -def test_update(decay): +@pytest.mark.parametrize("explicit_params", [True, False]) +def test_update(decay, explicit_params): model = torch.nn.Linear(10, 2, bias=False) with torch.no_grad(): model.weight.fill_(0.0) @@ -87,10 +106,29 @@ def test_update(decay): ) with torch.no_grad(): model.weight.fill_(1.0) - ema.update(model.parameters()) + if explicit_params: + ema.update(model.parameters()) + else: + ema.update() assert torch.all(model.weight == 1.0), "ema.update changed model weights" - ema.copy_to(model.parameters()) + if explicit_params: + ema.copy_to(model.parameters()) + else: + ema.copy_to() assert torch.allclose( model.weight, torch.full(size=(1,), fill_value=(1.0 - decay)) ), "average was wrong" + + +def test_explicit_params(): + model = torch.nn.Linear(10, 2) + with torch.no_grad(): + model.weight.fill_(0.0) + ema = ExponentialMovingAverage(model.parameters(), decay=0.9) + model2 = torch.nn.Linear(10, 2) + with torch.no_grad(): + model2.weight.fill_(1.0) + ema.update(model2.parameters()) + ema.copy_to() + assert not torch.all(model.weight == 0.0) \ No newline at end of file -- cgit v1.2.3