
Commit

rename test
rayg1234 committed Aug 15, 2024
1 parent 86c2e9a commit b7dd68f
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tests/core/models/test_equiformer_v2.py
@@ -242,7 +242,7 @@ def _load_hydra_model():
model.backbone.num_layers = 1
return model

def test_eqv2_hydra():
def test_eqv2_hydra_activation_checkpoint():
atoms = read(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
index=0,
@@ -260,17 +260,18 @@ def test_eqv2_hydra():
ac_model = _load_hydra_model()
ac_model.backbone.activation_checkpoint = True

# to run this test we need both models to see the exact same random state, and the only
# way to do this is to save the RNG state and restore it after stepping the first model
start_rng_state = torch.random.get_rng_state()
outputs_no_ac = no_ac_model(inputs)
print(outputs_no_ac)
torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum())

# restore the RNG state saved at the beginning
torch.random.set_rng_state(start_rng_state)
outputs_ac = ac_model(inputs)
print(outputs_ac)
torch.autograd.backward(outputs_ac["energy"].sum() + outputs_ac["forces"].sum())

# assert all the gradients are identical between the model with and without checkpointing
ac_model_grad_dict = {name: p.grad for name, p in ac_model.named_parameters() if p.grad is not None}
no_ac_model_grad_dict = {name: p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None}
for name in no_ac_model_grad_dict:
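For context on the pattern the added comments describe: the test saves the global RNG state, runs the non-checkpointed model, restores the state, runs the checkpointed model, and then compares gradients parameter by parameter. Below is a minimal standalone sketch of that pattern, using a hypothetical toy nn.Sequential model in place of the EquiformerV2 hydra models; it assumes only the public PyTorch API (torch.random.get_rng_state, torch.random.set_rng_state, torch.testing.assert_close) and is not part of the commit itself.

# Minimal sketch of the save/restore RNG pattern used in the test above.
# The toy models here are hypothetical stand-ins for the checkpointed and
# non-checkpointed hydra models.
import copy

import torch
import torch.nn as nn

model_a = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
model_b = copy.deepcopy(model_a)  # identical weights, like the two hydra models

inputs = torch.randn(4, 8)

# Save the RNG state so both forward passes draw the same dropout masks.
start_rng_state = torch.random.get_rng_state()

model_a(inputs).sum().backward()

# Restore the RNG state before stepping the second model.
torch.random.set_rng_state(start_rng_state)

model_b(inputs).sum().backward()

# The gradients should now match parameter by parameter, mirroring the
# comparison loop at the end of the test.
grads_a = {name: p.grad for name, p in model_a.named_parameters() if p.grad is not None}
grads_b = {name: p.grad for name, p in model_b.named_parameters() if p.grad is not None}
for name in grads_a:
    torch.testing.assert_close(grads_a[name], grads_b[name])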
