diff --git a/tests/core/models/test_equiformer_v2.py b/tests/core/models/test_equiformer_v2.py
index f3c8870a2..54d58db1c 100644
--- a/tests/core/models/test_equiformer_v2.py
+++ b/tests/core/models/test_equiformer_v2.py
@@ -242,7 +242,7 @@ def _load_hydra_model():
     model.backbone.num_layers = 1
     return model

-def test_eqv2_hydra():
+def test_eqv2_hydra_activation_checkpoint():
     atoms = read(
         os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
         index=0,
@@ -260,17 +260,18 @@ def test_eqv2_hydra():
     ac_model = _load_hydra_model()
     ac_model.backbone.activation_checkpoint=True

+    # For this test both models need to start from the exact same RNG state; the only
+    # way to do this is to save the RNG state and restore it after stepping the first model.
     start_rng_state = torch.random.get_rng_state()
     outputs_no_ac = no_ac_model(inputs)
-    print(outputs_no_ac)
     torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum())

     # reset the rng state to the beginning
     torch.random.set_rng_state(start_rng_state)
     outptuts_ac = ac_model(inputs)
-    print(outptuts_ac)
     torch.autograd.backward(outptuts_ac["energy"].sum() + outptuts_ac["forces"].sum())

+    # assert that all gradients are identical between the model with activation checkpointing and the one without
     ac_model_grad_dict = {name:p.grad for name, p in ac_model.named_parameters() if p.grad is not None}
     no_ac_model_grad_dict = {name:p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None}
     for name in no_ac_model_grad_dict:
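
For reference, here is a minimal, self-contained sketch (not part of the PR) of the RNG-state trick the new comments describe: snapshot `torch.random.get_rng_state()` before the first forward/backward pass and restore it before the second, so both models consume identical random draws (e.g. dropout masks) and their gradients can be compared exactly. The helper name `grads_match` and the `model_a`/`model_b`/`inputs` arguments are hypothetical; the output keys `"energy"`/`"forces"` simply mirror the test above.

```python
import torch

def grads_match(model_a, model_b, inputs):
    # Snapshot the global CPU RNG state before the first forward/backward pass.
    rng_state = torch.random.get_rng_state()

    out_a = model_a(inputs)
    torch.autograd.backward(out_a["energy"].sum() + out_a["forces"].sum())

    # Rewind the RNG so model_b consumes exactly the same random draws
    # (e.g. dropout masks) that model_a just did.
    torch.random.set_rng_state(rng_state)

    out_b = model_b(inputs)
    torch.autograd.backward(out_b["energy"].sum() + out_b["forces"].sum())

    # Collect gradients by parameter name and compare them pairwise.
    grads_a = {n: p.grad for n, p in model_a.named_parameters() if p.grad is not None}
    grads_b = {n: p.grad for n, p in model_b.named_parameters() if p.grad is not None}
    return grads_a.keys() == grads_b.keys() and all(
        torch.allclose(grads_a[n], grads_b[n]) for n in grads_a
    )
```

Comparing gradients by parameter name rather than by position, as the test also does with `named_parameters()`, makes it straightforward to report exactly which parameter's gradient diverges when the check fails.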