Commit

Merge branch 'main' into refactor_equiv2
misko authored Aug 15, 2024
2 parents 0d28b7e + ef2a4bc commit cf291da
Showing 2 changed files with 87 additions and 15 deletions.
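The substantive change is an opt-in activation_checkpoint flag for EquiformerV2: when set, each transformer block (and the force head's SO2 attention block) is invoked through torch.utils.checkpoint.checkpoint, so its intermediate activations are recomputed during the backward pass instead of being kept in memory. The sketch below shows the same pattern on a toy model; DemoBlock and DemoModel are illustrative stand-ins rather than fairchem classes, and only the torch.utils.checkpoint.checkpoint(...) call, including the use_reentrant=not self.training choice, mirrors the diff.

import torch
import torch.nn as nn
import torch.utils.checkpoint


class DemoBlock(nn.Module):
    """Stand-in for an EquiformerV2 block; only the checkpointing pattern matters here."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


class DemoModel(nn.Module):
    def __init__(self, num_layers: int = 4, activation_checkpoint: bool = False):
        super().__init__()
        self.activation_checkpoint = activation_checkpoint
        self.blocks = nn.ModuleList([DemoBlock() for _ in range(num_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.activation_checkpoint:
                # Recompute this block's activations during backward instead of
                # caching them, trading extra compute for lower peak memory.
                x = torch.utils.checkpoint.checkpoint(
                    block, x, use_reentrant=not self.training
                )
            else:
                x = block(x)
        return x


if __name__ == "__main__":
    inputs = torch.randn(8, 16, requires_grad=True)
    model = DemoModel(activation_checkpoint=True)
    model(inputs).sum().backward()  # gradients match the non-checkpointed path

Note that the checkpointed branches in the diff pass data_batch and graph.node_offset positionally rather than as batch= / node_offset= keywords; the reentrant checkpoint variant does not forward keyword arguments to the wrapped callable, which is presumably why the arguments were switched to positional form.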
56 changes: 41 additions & 15 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -161,6 +161,7 @@ def __init__(
         avg_degree: float | None = None,
         use_energy_lin_ref: bool | None = False,
         load_energy_lin_ref: bool | None = False,
+        activation_checkpoint: bool | None = False,
     ):
         if mmax_list is None:
             mmax_list = [2]
@@ -174,6 +175,7 @@ def __init__(
             logging.error("You need to install e3nn==0.4.4 to use EquiformerV2.")
             raise ImportError

+        self.activation_checkpoint = activation_checkpoint
         self.use_pbc = use_pbc
         self.use_pbc_single = use_pbc_single
         self.regress_forces = regress_forces
@@ -481,14 +483,26 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
         ###############################################################

         for i in range(self.num_layers):
-            x = self.blocks[i](
-                x,  # SO3_Embedding
-                graph.atomic_numbers_full,
-                graph.edge_distance,
-                graph.edge_index,
-                batch=data_batch,  # for GraphDropPath
-                node_offset=graph.node_offset,
-            )
+            if self.activation_checkpoint:
+                x = torch.utils.checkpoint.checkpoint(
+                    self.blocks[i],
+                    x,  # SO3_Embedding
+                    graph.atomic_numbers_full,
+                    graph.edge_distance,
+                    graph.edge_index,
+                    data_batch,  # for GraphDropPath
+                    graph.node_offset,
+                    use_reentrant=not self.training,
+                )
+            else:
+                x = self.blocks[i](
+                    x,  # SO3_Embedding
+                    graph.atomic_numbers_full,
+                    graph.edge_distance,
+                    graph.edge_index,
+                    batch=data_batch,  # for GraphDropPath
+                    node_offset=graph.node_offset,
+                )

         # Final layer norm
         x.embedding = self.norm(x.embedding)
@@ -633,6 +647,7 @@ class EquiformerV2ForceHead(nn.Module, HeadInterface):
     def __init__(self, backbone):
         super().__init__()

+        self.activation_checkpoint = backbone.activation_checkpoint
         self.force_block = SO2EquivariantGraphAttention(
             backbone.sphere_channels,
             backbone.attn_hidden_channels,
@@ -660,13 +675,24 @@ def __init__(self, backbone):
         self.apply(backbone._uniform_init_rad_func_linear_weights)

     def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
-        forces = self.force_block(
-            emb["node_embedding"],
-            emb["graph"].atomic_numbers_full,
-            emb["graph"].edge_distance,
-            emb["graph"].edge_index,
-            node_offset=emb["graph"].node_offset,
-        )
+        if self.activation_checkpoint:
+            forces = torch.utils.checkpoint.checkpoint(
+                self.force_block,
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                emb["graph"].node_offset,
+                use_reentrant=not self.training,
+            )
+        else:
+            forces = self.force_block(
+                emb["node_embedding"],
+                emb["graph"].atomic_numbers_full,
+                emb["graph"].edge_distance,
+                emb["graph"].edge_index,
+                node_offset=emb["graph"].node_offset,
+            )
         forces = forces.embedding.narrow(1, 1, 3)
         forces = forces.view(-1, 3).contiguous()
         if gp_utils.initialized():
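As context for why the flag exists (this comparison is not part of the commit): a minimal peak-memory check on a throwaway deep MLP, assuming a CUDA device is available; the layer sizes and depth are arbitrary and the exact numbers will vary by hardware.

import torch
import torch.nn as nn
import torch.utils.checkpoint


def peak_memory_bytes(checkpointed: bool) -> int:
    """One forward/backward over a deep MLP, returning peak CUDA memory in bytes."""
    torch.manual_seed(0)
    blocks = nn.ModuleList([nn.Linear(4096, 4096) for _ in range(24)]).cuda()
    x = torch.randn(512, 4096, device="cuda", requires_grad=True)
    torch.cuda.reset_peak_memory_stats()
    h = x
    for block in blocks:
        if checkpointed:
            # Same checkpointing call as in the diff, in non-reentrant mode.
            h = torch.utils.checkpoint.checkpoint(block, h, use_reentrant=False)
        else:
            h = block(h)
    h.sum().backward()
    return torch.cuda.max_memory_allocated()


if torch.cuda.is_available():
    print(f"without checkpointing: {peak_memory_bytes(False) / 2**20:.0f} MiB")
    print(f"with checkpointing:    {peak_memory_bytes(True) / 2**20:.0f} MiB")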
46 changes: 46 additions & 0 deletions tests/core/models/test_equiformer_v2.py
@@ -10,10 +10,12 @@
 import copy
 import io
 import os
+from pathlib import Path

 import pytest
 import requests
 import torch
+import yaml
 from ase.io import read
 from torch.nn.parallel.distributed import DistributedDataParallel

@@ -230,3 +232,47 @@ def sign(x):
     embedding._l_primary(c)
     lp = embedding.embedding.clone()
     (test_matrix_lp == lp).all()
+
+
+def _load_hydra_model():
+    torch.manual_seed(4)
+    with open(Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml")) as yaml_file:
+        yaml_config = yaml.safe_load(yaml_file)
+    model = registry.get_model_class("hydra")(yaml_config["model"]["backbone"], yaml_config["model"]["heads"])
+    model.backbone.num_layers = 1
+    return model
+
+def test_eqv2_hydra_activation_checkpoint():
+    atoms = read(
+        os.path.join(os.path.dirname(os.path.abspath(__file__)), "atoms.json"),
+        index=0,
+        format="json",
+    )
+    a2g = AtomsToGraphs(
+        max_neigh=200,
+        radius=6,
+        r_edges=False,
+        r_fixed=True,
+    )
+    data_list = a2g.convert_all([atoms])
+    inputs = data_list_collater(data_list)
+    no_ac_model = _load_hydra_model()
+    ac_model = _load_hydra_model()
+    ac_model.backbone.activation_checkpoint = True
+
+    # for this test both models must start from the exact same state, and the only
+    # way to do that is to save the rng state and reset it after stepping the first model
+    start_rng_state = torch.random.get_rng_state()
+    outputs_no_ac = no_ac_model(inputs)
+    torch.autograd.backward(outputs_no_ac["energy"].sum() + outputs_no_ac["forces"].sum())
+
+    # reset the rng state to the beginning
+    torch.random.set_rng_state(start_rng_state)
+    outputs_ac = ac_model(inputs)
+    torch.autograd.backward(outputs_ac["energy"].sum() + outputs_ac["forces"].sum())
+
+    # assert the gradients are identical between the checkpointed and non-checkpointed models
+    ac_model_grad_dict = {name: p.grad for name, p in ac_model.named_parameters() if p.grad is not None}
+    no_ac_model_grad_dict = {name: p.grad for name, p in no_ac_model.named_parameters() if p.grad is not None}
+    for name in no_ac_model_grad_dict:
+        assert torch.allclose(no_ac_model_grad_dict[name], ac_model_grad_dict[name], atol=1e-4)
