Commit
refactor and deprecate old equiformerv2
misko committed Aug 15, 2024
1 parent f19c51b commit 57763bc
Showing 5 changed files with 852 additions and 270 deletions.
5 changes: 3 additions & 2 deletions src/fairchem/core/models/equiformer_v2/__init__.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
-from .equiformer_v2 import EquiformerV2
+from .equiformer_v2_deprecated import EquiformerV2
+from .equiformer_v2 import EquiformerV2BackboneAndHeads
 
-__all__ = ["EquiformerV2"]
+__all__ = ["EquiformerV2", "EquiformerV2BackboneAndHeads"]
326 changes: 58 additions & 268 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -10,7 +10,12 @@
from fairchem.core.common import gp_utils
from fairchem.core.common.registry import registry
from fairchem.core.common.utils import conditional_grad
from fairchem.core.models.base import BackboneInterface, GraphModelMixin, HeadInterface
from fairchem.core.models.base import (
BackboneInterface,
GraphModelMixin,
HeadInterface,
HydraModel,
)
from fairchem.core.models.scn.smearing import GaussianSmearing

with contextlib.suppress(ImportError):
@@ -54,8 +59,8 @@
_AVG_DEGREE = 23.395238876342773 # IS2RE: 100k, max_radius = 5, max_neighbors = 100


@registry.register_model("equiformer_v2")
class EquiformerV2(nn.Module, GraphModelMixin):
@registry.register_model("equiformer_v2_backbone")
class EquiformerV2Backbone(nn.Module, BackboneInterface, GraphModelMixin):
"""
Equiformer with graph attention built upon SO(2) convolution and feedforward network built upon S2 activation
@@ -355,43 +360,6 @@ def __init__(
lmax=max(self.lmax_list),
num_channels=self.sphere_channels,
)
self.energy_block = FeedForwardNetwork(
self.sphere_channels,
self.ffn_hidden_channels,
1,
self.lmax_list,
self.mmax_list,
self.SO3_grid,
self.ffn_activation,
self.use_gate_act,
self.use_grid_mlp,
self.use_sep_s2_act,
)
if self.regress_forces:
self.force_block = SO2EquivariantGraphAttention(
self.sphere_channels,
self.attn_hidden_channels,
self.num_heads,
self.attn_alpha_channels,
self.attn_value_channels,
1,
self.lmax_list,
self.mmax_list,
self.SO3_rotation,
self.mappingReduced,
self.SO3_grid,
self.max_num_elements,
self.edge_channels_list,
self.block_use_atom_edge_embedding,
self.use_m_share_rad,
self.attn_activation,
self.use_s2_act_attn,
self.use_attn_renorm,
self.use_gate_act,
self.use_sep_s2_act,
alpha_drop=0.0,
)

if self.load_energy_lin_ref:
self.energy_lin_ref = nn.Parameter(
torch.zeros(self.max_num_elements),
@@ -401,44 +369,8 @@ def __init__(
self.apply(self._init_weights)
self.apply(self._uniform_init_rad_func_linear_weights)

def _init_gp_partitions(
self,
atomic_numbers_full,
data_batch_full,
edge_index,
edge_distance,
edge_distance_vec,
):
"""Graph Parallel
This creates the required partial tensors for each rank given the full tensors.
The tensors are split on the dimension along the node index using node_partition.
"""
node_partition = gp_utils.scatter_to_model_parallel_region(
torch.arange(len(atomic_numbers_full)).to(self.device)
)
edge_partition = torch.where(
torch.logical_and(
edge_index[1] >= node_partition.min(),
edge_index[1] <= node_partition.max(), # TODO: 0 or 1?
)
)[0]
edge_index = edge_index[:, edge_partition]
edge_distance = edge_distance[edge_partition]
edge_distance_vec = edge_distance_vec[edge_partition]
atomic_numbers = atomic_numbers_full[node_partition]
data_batch = data_batch_full[node_partition]
node_offset = node_partition.min().item()
return (
atomic_numbers,
data_batch,
node_offset,
edge_index,
edge_distance,
edge_distance_vec,
)

@conditional_grad(torch.enable_grad())
def forward(self, data):
def forward(self, data: Batch) -> dict[str, torch.Tensor]:
self.batch_size = len(data.natoms)
self.dtype = data.pos.dtype
self.device = data.pos.device
@@ -562,63 +494,43 @@ def forward(self, data):
# Final layer norm
x.embedding = self.norm(x.embedding)

###############################################################
# Energy estimation
###############################################################
node_energy = self.energy_block(x)
node_energy = node_energy.embedding.narrow(1, 0, 1)
if gp_utils.initialized():
node_energy = gp_utils.gather_from_model_parallel_region(node_energy, dim=0)
energy = torch.zeros(
len(data.natoms),
device=node_energy.device,
dtype=node_energy.dtype,
)
energy.index_add_(0, graph.batch_full, node_energy.view(-1))
energy = energy / self.avg_num_nodes

# Add the per-atom linear references to the energy.
if self.use_energy_lin_ref and self.load_energy_lin_ref:
# During training, target E = (E_DFT - E_ref - E_mean) / E_std, and
# during inference, \hat{E_DFT} = \hat{E} * E_std + E_ref + E_mean
# where
#
# E_DFT = raw DFT energy,
# E_ref = reference energy,
# E_mean = normalizer mean,
# E_std = normalizer std,
# \hat{E} = predicted energy,
# \hat{E_DFT} = predicted DFT energy.
#
# We can also write this as
# \hat{E_DFT} = E_std * (\hat{E} + E_ref / E_std) + E_mean,
# which is why we save E_ref / E_std as the linear reference.
with torch.cuda.amp.autocast(False):
energy = energy.to(self.energy_lin_ref.dtype).index_add(
0,
graph.batch_full,
self.energy_lin_ref[graph.atomic_numbers_full],
)
return {"node_embedding": x, "graph": graph}

outputs = {"energy": energy}
###############################################################
# Force estimation
###############################################################
if self.regress_forces:
forces = self.force_block(
x,
graph.atomic_numbers_full,
graph.edge_distance,
graph.edge_index,
node_offset=graph.node_offset,
def _init_gp_partitions(
self,
atomic_numbers_full,
data_batch_full,
edge_index,
edge_distance,
edge_distance_vec,
):
"""Graph Parallel
This creates the required partial tensors for each rank given the full tensors.
The tensors are split on the dimension along the node index using node_partition.
"""
node_partition = gp_utils.scatter_to_model_parallel_region(
torch.arange(len(atomic_numbers_full)).to(self.device)
)
edge_partition = torch.where(
torch.logical_and(
edge_index[1] >= node_partition.min(),
edge_index[1] <= node_partition.max(), # TODO: 0 or 1?
)
forces = forces.embedding.narrow(1, 1, 3)
forces = forces.view(-1, 3).contiguous()
if gp_utils.initialized():
forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
outputs["forces"] = forces

return outputs
)[0]
edge_index = edge_index[:, edge_partition]
edge_distance = edge_distance[edge_partition]
edge_distance_vec = edge_distance_vec[edge_partition]
atomic_numbers = atomic_numbers_full[node_partition]
data_batch = data_batch_full[node_partition]
node_offset = node_partition.min().item()
return (
atomic_numbers,
data_batch,
node_offset,
edge_index,
edge_distance,
edge_distance_vec,
)

# Initialize the edge rotation matrices
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
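
For readers unfamiliar with the graph-parallel path: `_init_gp_partitions` (relocated above, unchanged in logic) keeps only the edges whose target node, `edge_index[1]`, falls inside the contiguous node slice owned by the current model-parallel rank, and records `node_offset` so node indices can later be re-based. A minimal standalone sketch of that edge-filtering step, using plain PyTorch and hypothetical toy tensors instead of `gp_utils`:

```python
import torch

# Toy graph: 6 nodes; edge_index rows are [source, target].
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                           [1, 2, 3, 4, 5, 0]])
edge_distance = torch.rand(edge_index.shape[1])

# Pretend this rank owns the contiguous node slice [2, 3, 4]
# (the real code gets this from gp_utils.scatter_to_model_parallel_region).
node_partition = torch.tensor([2, 3, 4])

# Keep only edges whose target node lives on this rank.
edge_partition = torch.where(
    (edge_index[1] >= node_partition.min())
    & (edge_index[1] <= node_partition.max())
)[0]

local_edge_index = edge_index[:, edge_partition]      # tensor([[1, 2, 3], [2, 3, 4]])
local_edge_distance = edge_distance[edge_partition]
node_offset = node_partition.min().item()             # 2, used to re-base node indices
```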
@@ -682,142 +594,6 @@ def no_weight_decay(self) -> set:
return set(no_wd_list)


@registry.register_model("equiformer_v2_backbone")
class EquiformerV2Backbone(EquiformerV2, BackboneInterface):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO remove these once we deprecate/stop-inheriting EquiformerV2 class
self.energy_block = None
self.force_block = None

@conditional_grad(torch.enable_grad())
def forward(self, data: Batch) -> dict[str, torch.Tensor]:
self.batch_size = len(data.natoms)
self.dtype = data.pos.dtype
self.device = data.pos.device
atomic_numbers = data.atomic_numbers.long()
graph = self.generate_graph(
data,
enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly,
)

data_batch = data.batch
if gp_utils.initialized():
(
atomic_numbers,
data_batch,
node_offset,
edge_index,
edge_distance,
edge_distance_vec,
) = self._init_gp_partitions(
graph.atomic_numbers_full,
graph.batch_full,
graph.edge_index,
graph.edge_distance,
graph.edge_distance_vec,
)
graph.node_offset = node_offset
graph.edge_index = edge_index
graph.edge_distance = edge_distance
graph.edge_distance_vec = edge_distance_vec

###############################################################
# Entering Graph Parallel Region
# after this point, if using gp, then node, edge tensors are split
# across the graph parallel ranks, some full tensors such as
# atomic_numbers_full are required because we need to index into the
# full graph when computing edge embeddings or reducing nodes from neighbors
#
# all tensors that do not have the suffix "_full" refer to the partial tensors.
# if not using gp, the full values are equal to the partial values
# ie: atomic_numbers_full == atomic_numbers
###############################################################

###############################################################
# Initialize data structures
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
for i in range(self.num_resolutions):
self.SO3_rotation[i].set_wigner(edge_rot_mat)

###############################################################
# Initialize node embeddings
###############################################################

# Init per node representations using an atomic number based embedding
x = SO3_Embedding(
len(atomic_numbers),
self.lmax_list,
self.sphere_channels,
self.device,
self.dtype,
)

offset_res = 0
offset = 0
# Initialize the l = 0, m = 0 coefficients for each resolution
for i in range(self.num_resolutions):
if self.num_resolutions == 1:
x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)
else:
x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[
:, offset : offset + self.sphere_channels
]
offset = offset + self.sphere_channels
offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2)

# Edge encoding (distance and atom edge)
graph.edge_distance = self.distance_expansion(graph.edge_distance)
if self.share_atom_edge_embedding and self.use_atom_edge_embedding:
source_element = graph.atomic_numbers_full[
graph.edge_index[0]
] # Source atom atomic number
target_element = graph.atomic_numbers_full[
graph.edge_index[1]
] # Target atom atomic number
source_embedding = self.source_embedding(source_element)
target_embedding = self.target_embedding(target_element)
graph.edge_distance = torch.cat(
(graph.edge_distance, source_embedding, target_embedding), dim=1
)

# Edge-degree embedding
edge_degree = self.edge_degree_embedding(
graph.atomic_numbers_full,
graph.edge_distance,
graph.edge_index,
len(atomic_numbers),
graph.node_offset,
)
x.embedding = x.embedding + edge_degree.embedding

###############################################################
# Update spherical node embeddings
###############################################################

for i in range(self.num_layers):
x = self.blocks[i](
x, # SO3_Embedding
graph.atomic_numbers_full,
graph.edge_distance,
graph.edge_index,
batch=data_batch, # for GraphDropPath
node_offset=graph.node_offset,
)

# Final layer norm
x.embedding = self.norm(x.embedding)

return {"node_embedding": x, "graph": graph}


@registry.register_model("equiformer_v2_energy_head")
class EquiformerV2EnergyHead(nn.Module, HeadInterface):
def __init__(self, backbone):
@@ -897,3 +673,17 @@ def forward(self, data: Batch, emb: dict[str, torch.Tensor]):
if gp_utils.initialized():
forces = gp_utils.gather_from_model_parallel_region(forces, dim=0)
return {"forces": forces}


@registry.register_model("equiformer_v2_backbone_and_heads")
class EquiformerV2BackboneAndHeads(nn.Module):
def __init__(self, **kwargs):
super().__init__()
kwargs["model"] = "equiformer_v2_backbone"
heads = {"energy": {"module": "equiformer_v2_energy_head"}}
if "regress_forces" in kwargs and kwargs["regress_forces"]:
heads["forces"] = {"module": "equiformer_v2_force_head"}
self.model = HydraModel(backbone=kwargs, heads=heads)

def forward(self, data: Batch):
return self.model(data)
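
The new `EquiformerV2BackboneAndHeads` wrapper reproduces the old single-model interface by routing its keyword arguments into the `equiformer_v2_backbone` entry of a `HydraModel`, adding an energy head and, when `regress_forces` is set, a force head. A hedged usage sketch, assuming the backbone's constructor defaults are acceptable and that `data` is a `torch_geometric` `Batch` prepared the same way as for the old model:

```python
from fairchem.core.models.equiformer_v2 import EquiformerV2BackboneAndHeads

# Keyword arguments are forwarded to the "equiformer_v2_backbone" config inside
# HydraModel (illustrative; the backbone accepts many more hyperparameters).
model = EquiformerV2BackboneAndHeads(regress_forces=True)

# outputs = model(data)  # expected: {"energy": ..., "forces": ...} from the two heads
```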