
separate eqv2 compile code
rayg1234 committed Sep 2, 2024
1 parent 94e4a7f commit dabf29f
Showing 5 changed files with 478 additions and 97 deletions.
11 changes: 11 additions & 0 deletions src/fairchem/core/common/distutils.py
@@ -17,6 +17,7 @@
import torch.distributed as dist

from fairchem.core.common.typing import none_throws
from torch.distributed.elastic.utils.distributed import get_free_port

T = TypeVar("T")

@@ -192,3 +193,13 @@ def gather_objects(data: T, group: dist.ProcessGroup = dist.group.WORLD) -> list
output = [None for _ in range(get_world_size())] if is_master() else None
dist.gather_object(data, output, group=group, dst=0)
return output

def init_local_distributed_process_group(backend="nccl"):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
dist.init_process_group(
rank=0,
world_size=1,
backend=backend,
timeout=timedelta(seconds=10), # setting up timeout for distributed collectives
)
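
A minimal usage sketch for the new helper, assuming a CPU-only setting (the gloo backend and the toy tensor are illustrative, not part of this commit):

import torch
import torch.distributed as dist

from fairchem.core.common.distutils import init_local_distributed_process_group

# Spin up a single-rank process group on a free local port; "gloo" avoids needing a GPU.
init_local_distributed_process_group(backend="gloo")

# Collectives behave as usual, even with world_size == 1.
t = torch.ones(4)
dist.all_reduce(t)
assert torch.equal(t, torch.ones(4))

dist.destroy_process_group()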
8 changes: 4 additions & 4 deletions src/fairchem/core/common/utils.py
@@ -763,6 +763,7 @@ def radius_graph_pbc(
atom_distance=atom_distance_sqr,
max_num_neighbors_threshold=max_num_neighbors_threshold,
enforce_max_strictly=enforce_max_neighbors_strictly,
batch=data.batch,
)

if not torch.all(mask_num_neighbors):
@@ -786,6 +787,7 @@ def get_max_neighbors_mask(
max_num_neighbors_threshold,
degeneracy_tolerance: float = 0.01,
enforce_max_strictly: bool = False,
batch=None,
):
"""
Give a mask that filters out edges so that each atom has at most
@@ -808,14 +810,12 @@
# Get number of neighbors
# segment_coo assumes sorted index
ones = index.new_ones(1).expand_as(index)
-    num_neighbors = segment_coo(ones, index, dim_size=num_atoms)
+    num_neighbors = scatter(ones, index, dim_size=num_atoms)
max_num_neighbors = num_neighbors.max()
num_neighbors_thresholded = num_neighbors.clamp(max=max_num_neighbors_threshold)

# Get number of (thresholded) neighbors per image
-    image_indptr = torch.zeros(natoms.shape[0] + 1, device=device, dtype=torch.long)
-    image_indptr[1:] = torch.cumsum(natoms, dim=0)
-    num_neighbors_image = segment_csr(num_neighbors_thresholded, image_indptr)
+    num_neighbors_image = scatter(num_neighbors_thresholded, batch, dim_size=natoms.shape[0])

# If max_num_neighbors is below the threshold, return early
if (
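
The hunks above swap the sorted-index segment_coo/segment_csr reductions for scatter reductions driven by the per-atom batch vector. A self-contained sketch of what the two reductions compute, written with plain torch ops rather than fairchem's scatter helper (the toy sizes are illustrative):

import torch

# Toy system: two images, the first with 2 atoms, the second with 1 atom.
natoms = torch.tensor([2, 1])
batch = torch.tensor([0, 0, 1])            # image index per atom
index = torch.tensor([0, 0, 1, 2, 2, 2])   # source atom of each edge

num_atoms = int(natoms.sum())

# Neighbors per atom: scatter-add ones over the edge source index.
# Unlike segment_coo, this does not require `index` to be sorted.
ones = torch.ones_like(index)
num_neighbors = torch.zeros(num_atoms, dtype=torch.long).scatter_add_(0, index, ones)
# -> tensor([2, 1, 3])

# Neighbors per image: scatter the per-atom counts by `batch`,
# replacing the image_indptr + segment_csr construction.
num_neighbors_image = torch.zeros(natoms.shape[0], dtype=torch.long).scatter_add_(0, batch, num_neighbors)
# -> tensor([3, 3])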
140 changes: 140 additions & 0 deletions src/fairchem/core/models/equiformer_v2/so2_ops.py
@@ -139,6 +139,7 @@ def __init__(
self.rad_func = RadialFunction(self.edge_channels_list)

def forward(self, x, x_edge):

num_edges = len(x_edge)
out = []

@@ -352,3 +353,142 @@ def forward(self, x, x_edge):
out_embedding._l_primary(self.mappingReduced)

return out_embedding

class SO2_Convolution_Exportable(torch.nn.Module):
"""
SO(2) Block: Perform SO(2) convolutions for all m (orders)
Args:
sphere_channels (int): Number of spherical channels
m_output_channels (int): Number of output channels used during the SO(2) conv
lmax_list (list:int): List of degrees (l) for each resolution
mmax_list (list:int): List of orders (m) for each resolution
mappingReduced (CoefficientMappingModule): Used to extract a subset of m components
        internal_weights (bool):        If True, do not use a radial function to multiply the input features
edge_channels_list (list:int): List of sizes of invariant edge embedding. For example, [input_channels, hidden_channels, hidden_channels].
extra_m0_output_channels (int): If not None, return `out_embedding` (SO3_Embedding) and `extra_m0_features` (Tensor).
"""

def __init__(
self,
sphere_channels: int,
m_output_channels: int,
lmax_list: list[int],
mmax_list: list[int],
mappingReduced,
internal_weights: bool = True,
edge_channels_list: list[int] | None = None,
extra_m0_output_channels: int | None = None,
):
super().__init__()
self.sphere_channels = sphere_channels
self.m_output_channels = m_output_channels
self.lmax_list = lmax_list
self.mmax_list = mmax_list
self.mappingReduced = mappingReduced
self.num_resolutions = len(lmax_list)
self.internal_weights = internal_weights
self.edge_channels_list = copy.deepcopy(edge_channels_list)
self.extra_m0_output_channels = extra_m0_output_channels

num_channels_rad = 0 # for radial function

num_channels_m0 = 0
for i in range(self.num_resolutions):
num_coefficients = self.lmax_list[i] + 1
num_channels_m0 = num_channels_m0 + num_coefficients * self.sphere_channels

# SO(2) convolution for m = 0
m0_output_channels = self.m_output_channels * (
num_channels_m0 // self.sphere_channels
)
if self.extra_m0_output_channels is not None:
m0_output_channels = m0_output_channels + self.extra_m0_output_channels
self.fc_m0 = Linear(num_channels_m0, m0_output_channels)
num_channels_rad = num_channels_rad + self.fc_m0.in_features

# SO(2) convolution for non-zero m
self.so2_m_conv = nn.ModuleList()
for m in range(1, max(self.mmax_list) + 1):
self.so2_m_conv.append(
SO2_m_Convolution(
m,
self.sphere_channels,
self.m_output_channels,
self.lmax_list,
self.mmax_list,
)
)
num_channels_rad = num_channels_rad + self.so2_m_conv[-1].fc.in_features

# Embedding function of distance
self.rad_func = None
if not self.internal_weights:
assert self.edge_channels_list is not None
self.edge_channels_list.append(int(num_channels_rad))
self.rad_func = RadialFunction(self.edge_channels_list)

def forward(self, x_emb, x_edge):
# x_emb: [num_edges, num_sh_coefs, num_features]
# x_edge: [num_edges, num_edge_features]

num_edges = x_edge.shape[0]
out = []
        # torch export does not support narrowing inputs based on a buffered tensor, so read m_size out first
m_size = self.mappingReduced.m_size

# Reshape the spherical harmonics based on m (order), equivalent to x._m_primary
x_emb = torch.einsum("nac, ba -> nbc", x_emb, self.mappingReduced.to_m)

# radial function
if self.rad_func is not None:
x_edge = self.rad_func(x_edge)
offset_rad = 0

# Compute m=0 coefficients separately since they only have real values (no imaginary)
x_0 = x_emb.narrow(1, 0, m_size[0])
x_0 = x_0.reshape(x_edge.shape[0], -1)
if self.rad_func is not None:
x_edge_0 = x_edge.narrow(1, 0, self.fc_m0.in_features)
x_0 = x_0 * x_edge_0
x_0 = self.fc_m0(x_0)

x_0_extra = None
# extract extra m0 features
if self.extra_m0_output_channels is not None:
x_0_extra = x_0.narrow(-1, 0, self.extra_m0_output_channels)
x_0 = x_0.narrow(
-1,
self.extra_m0_output_channels,
(self.fc_m0.out_features - self.extra_m0_output_channels),
)

x_0 = x_0.view(num_edges, -1, self.m_output_channels)
out.append(x_0)
offset_rad = offset_rad + self.fc_m0.in_features

# Compute the values for the m > 0 coefficients
offset = m_size[0]
for m in range(1, max(self.mmax_list) + 1):
# Get the m order coefficients
x_m = x_emb.narrow(1, offset, 2 * m_size[m])
x_m = x_m.reshape(num_edges, 2, -1)

# Perform SO(2) convolution
if self.rad_func is not None:
x_edge_m = x_edge.narrow(
1, offset_rad, self.so2_m_conv[m - 1].fc.in_features
)
x_edge_m = x_edge_m.reshape(
num_edges, 1, self.so2_m_conv[m - 1].fc.in_features
)
x_m = x_m * x_edge_m
x_m = self.so2_m_conv[m - 1](x_m)
x_m = x_m.view(num_edges, -1, self.m_output_channels)
out.append(x_m)
offset = offset + 2 * m_size[m]
offset_rad = offset_rad + self.so2_m_conv[m - 1].fc.in_features

out = torch.cat(out, dim=1)
out = torch.einsum("nac, ab -> nbc", out, self.mappingReduced.to_m)
return out
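
A small self-contained sketch of the einsum reordering that the exportable forward uses in place of the in-place _m_primary / _l_primary buffer indexing (the 4-coefficient permutation below is a toy stand-in for the real to_m mapping):

import torch

# Toy permutation over 4 coefficients; the real to_m reorders l-primary -> m-primary.
perm = torch.tensor([2, 0, 3, 1])
to_m = torch.zeros(4, 4)
to_m[torch.arange(4), perm] = 1.0    # to_m[b, a] = 1 where a = perm[b]

x = torch.randn(5, 4, 8)             # [num_edges, num_coefficients, num_channels]

# Same contraction as the forward above: a dense matmul reorders the
# coefficients without index-based buffer writes.
x_m_primary = torch.einsum("nac, ba -> nbc", x, to_m)
assert torch.allclose(x_m_primary, x[:, perm, :])

# The closing einsum applies the inverse permutation, restoring the original layout.
x_back = torch.einsum("nac, ab -> nbc", x_m_primary, to_m)
assert torch.allclose(x_back, x)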