aboutsummaryrefslogtreecommitdiff
path: root/tests/test_ema.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_ema.py')
-rw-r--r--tests/test_ema.py64
1 files changed, 51 insertions, 13 deletions
diff --git a/tests/test_ema.py b/tests/test_ema.py
index 6d7e43e..ad6ee37 100644
--- a/tests/test_ema.py
+++ b/tests/test_ema.py
@@ -7,7 +7,8 @@ from torch_ema import ExponentialMovingAverage
@pytest.mark.parametrize("decay", [0.995, 0.9])
@pytest.mark.parametrize("use_num_updates", [True, False])
-def test_val_error(decay, use_num_updates):
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_val_error(decay, use_num_updates, explicit_params):
"""Confirm that EMA validation error is lower than raw validation error."""
torch.manual_seed(0)
x_train = torch.rand((100, 10))
@@ -30,27 +31,37 @@ def test_val_error(decay, use_num_updates):
optimizer.zero_grad()
loss.backward()
optimizer.step()
- ema.update(model.parameters())
+ if explicit_params:
+ ema.update(model.parameters())
+ else:
+ ema.update()
# Validation: original
model.eval()
logits = model(x_val)
loss_orig = torch.nn.functional.cross_entropy(logits, y_val)
- print(f"Original loss: {loss_orig}")
# Validation: with EMA
# First save original parameters before replacing with EMA version
- ema.store(model.parameters())
+ if explicit_params:
+ ema.store(model.parameters())
+ else:
+ ema.store()
# Copy EMA parameters to model
- ema.copy_to(model.parameters())
+ if explicit_params:
+ ema.copy_to(model.parameters())
+ else:
+ ema.copy_to()
logits = model(x_val)
loss_ema = torch.nn.functional.cross_entropy(logits, y_val)
- print(f"EMA loss: {loss_ema}")
assert loss_ema < loss_orig, "EMA loss wasn't lower"
# Test restore
- ema.restore(model.parameters())
+ if explicit_params:
+ ema.restore(model.parameters())
+ else:
+ ema.restore()
model.eval()
logits = model(x_val)
loss_orig2 = torch.nn.functional.cross_entropy(logits, y_val)
@@ -60,7 +71,8 @@ def test_val_error(decay, use_num_updates):
@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
@pytest.mark.parametrize("use_num_updates", [True, False])
-def test_store_restore(decay, use_num_updates):
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_store_restore(decay, use_num_updates, explicit_params):
model = torch.nn.Linear(10, 2)
ema = ExponentialMovingAverage(
model.parameters(),
@@ -68,15 +80,22 @@ def test_store_restore(decay, use_num_updates):
use_num_updates=use_num_updates
)
orig_weight = model.weight.clone().detach()
- ema.store(model.parameters())
+ if explicit_params:
+ ema.store(model.parameters())
+ else:
+ ema.store()
with torch.no_grad():
model.weight.uniform_(0.0, 1.0)
- ema.restore(model.parameters())
+ if explicit_params:
+ ema.restore(model.parameters())
+ else:
+ ema.restore()
assert torch.all(model.weight == orig_weight)
@pytest.mark.parametrize("decay", [0.995, 0.9, 0.0, 1.0])
-def test_update(decay):
+@pytest.mark.parametrize("explicit_params", [True, False])
+def test_update(decay, explicit_params):
model = torch.nn.Linear(10, 2, bias=False)
with torch.no_grad():
model.weight.fill_(0.0)
@@ -87,10 +106,29 @@ def test_update(decay):
)
with torch.no_grad():
model.weight.fill_(1.0)
- ema.update(model.parameters())
+ if explicit_params:
+ ema.update(model.parameters())
+ else:
+ ema.update()
assert torch.all(model.weight == 1.0), "ema.update changed model weights"
- ema.copy_to(model.parameters())
+ if explicit_params:
+ ema.copy_to(model.parameters())
+ else:
+ ema.copy_to()
assert torch.allclose(
model.weight,
torch.full(size=(1,), fill_value=(1.0 - decay))
), "average was wrong"
+
+
+def test_explicit_params():
+ model = torch.nn.Linear(10, 2)
+ with torch.no_grad():
+ model.weight.fill_(0.0)
+ ema = ExponentialMovingAverage(model.parameters(), decay=0.9)
+ model2 = torch.nn.Linear(10, 2)
+ with torch.no_grad():
+ model2.weight.fill_(1.0)
+ ema.update(model2.parameters())
+ ema.copy_to()
+ assert not torch.all(model.weight == 0.0) \ No newline at end of file