Skip to content

Commit

Permalink
message block compiles and exports
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Sep 2, 2024
1 parent 7f90152 commit f56b440
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 60 deletions.
34 changes: 15 additions & 19 deletions src/fairchem/core/models/escn/escn_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CoefficientMapping,
SO3_Grid,
SO3_Rotation,
rotation_to_wigner
)
from fairchem.core.models.escn.so3 import SO3_Embedding
from fairchem.core.models.scn.sampling import CalcSpherePoints
Expand Down Expand Up @@ -177,14 +178,6 @@ def __init__(
self.SO3_grid = nn.ModuleDict()
self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution)
self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution)
# self.SO3_grid = nn.ModuleList()
# for lval in range(max(self.lmax_list) + 1):
# SO3_m_grid = nn.ModuleList()
# for m in range(max(self.lmax_list) + 1):
# SO3_m_grid.append(SO3_Grid(lval, m, resolution=resolution))

# self.SO3_grid.append(SO3_m_grid)
# import pdb;pdb.set_trace()
self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list)

# Initialize the blocks for each layer of the GNN
Expand Down Expand Up @@ -257,9 +250,7 @@ def forward(self, data):
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = SO3_Rotation(edge_rot_mat, self.lmax_list[0])
wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax_list[0]).detach()

###############################################################
# Initialize node embeddings
Expand Down Expand Up @@ -296,7 +287,7 @@ def forward(self, data):
atomic_numbers,
graph.edge_distance,
graph.edge_index,
self.SO3_edge_rot,
wigner,
)

# Residual layer for all layers past the first
Expand All @@ -309,7 +300,7 @@ def forward(self, data):
atomic_numbers,
graph.edge_distance,
graph.edge_index,
self.SO3_edge_rot,
wigner,
)
x.embedding = x_message

Expand Down Expand Up @@ -499,15 +490,15 @@ def forward(
atomic_numbers: torch.Tensor,
edge_distance: torch.Tensor,
edge_index: torch.Tensor,
SO3_edge_rot: SO3_Rotation,
wigner: torch.Tensor,
) -> torch.Tensor:
# Compute messages by performing message block
x_message = self.message_block(
x,
atomic_numbers,
edge_distance,
edge_index,
SO3_edge_rot,
wigner,
)
print(f"x_message: {x_message.mean()}")

Expand Down Expand Up @@ -580,6 +571,7 @@ def __init__(
self.mmax_list = mmax_list
self.edge_channels = edge_channels
self.mappingReduced = mappingReduced
self.out_mask = self.mappingReduced.coefficient_idx(self.lmax_list[0], self.mmax_list[0])

# Create edge scalar (invariant to rotations) features
self.edge_block = EdgeBlock(
Expand Down Expand Up @@ -615,7 +607,7 @@ def forward(
atomic_numbers: torch.Tensor,
edge_distance: torch.Tensor,
edge_index: torch.Tensor,
SO3_edge_rot: SO3_Rotation,
wigner: torch.Tensor,
) -> torch.Tensor:
###############################################################
# Compute messages
Expand All @@ -635,8 +627,10 @@ def forward(
x_target = x_target[edge_index[1, :]]

# Rotate the irreps to align with the edge
x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0])
x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0])
x_source = torch.bmm(wigner[:, self.out_mask, :], x_source)
x_target = torch.bmm(wigner[:, self.out_mask, :], x_target)
# x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0])
# x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0])

# Compute messages
x_source = self.so2_block_source(x_source, x_edge)
Expand All @@ -653,7 +647,9 @@ def forward(
x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)

# Rotate back the irreps
x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0])
# x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0])
wigner_inv = torch.transpose(wigner, 1, 2).contiguous()
x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target)

# Compute the sum of the incoming neighboring messages for each target node
new_embedding = torch.fill(x.clone(), 0)
Expand Down
59 changes: 27 additions & 32 deletions src/fairchem/core/models/utils/so3_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,28 @@ def _z_rot_mat(angle: torch.Tensor, lv: int) -> torch.Tensor:
M[..., inds, inds] = torch.cos(frequencies * angle[..., None])
return M

def rotation_to_wigner(
edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int
) -> torch.Tensor:
x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
alpha, beta = o3.xyz_to_angles(x)
R = (
o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2)
@ edge_rot_mat
)
gamma = torch.atan2(R[..., 0, 2], R[..., 0, 0])

size = int((end_lmax + 1) ** 2) - int((start_lmax) ** 2)
wigner = torch.zeros(len(alpha), size, size, device=edge_rot_mat.device)
start = 0
for lmax in range(start_lmax, end_lmax + 1):
block = wigner_D(lmax, alpha, beta, gamma)
end = start + block.size()[1]
wigner[:, start:end, start:end] = block
start = end

return wigner.detach()


class CoefficientMapping(torch.nn.Module):
"""
Expand All @@ -63,13 +85,11 @@ def __init__(
self,
lmax_list,
mmax_list,
use_rotate_inv_rescale=False
):
super().__init__()

self.lmax_list = lmax_list
self.mmax_list = mmax_list
self.use_rotate_inv_rescale = use_rotate_inv_rescale
self.num_resolutions = len(lmax_list)

assert (len(self.lmax_list) == 1) and (len(self.mmax_list) == 1)
Expand Down Expand Up @@ -121,13 +141,9 @@ def __init__(
self.register_buffer('l_harmonic', l_harmonic)
self.register_buffer('m_harmonic', m_harmonic)
self.register_buffer('m_complex', m_complex)
# self.register_buffer('res_size', res_size)
self.register_buffer('to_m', to_m)
# self.register_buffer('m_size', m_size)

self.pre_compute_coefficient_idx()
if self.use_rotate_inv_rescale:
self.pre_compute_rotate_inv_rescale()


# Return mask containing coefficients of order m (real and imaginary parts)
Expand Down Expand Up @@ -215,30 +231,11 @@ def pre_compute_rotate_inv_rescale(self):
rotate_inv_rescale[:, start_idx : (start_idx + length), start_idx : (start_idx + length)] = rescale_factor
rotate_inv_rescale = rotate_inv_rescale[:, :, mask_indices]
self.register_buffer('rotate_inv_rescale_l{}_m{}'.format(l, m), rotate_inv_rescale)


def prepare_rotate_inv_rescale(self):
lmax = max(self.lmax_list)
rotate_inv_rescale_list = []
for l in range(lmax + 1):
l_list = []
for m in range(lmax + 1):
l_list.append(getattr(self, 'rotate_inv_rescale_l{}_m{}'.format(l, m), None))
rotate_inv_rescale_list.append(l_list)
return rotate_inv_rescale_list


# Return the re-scaling for rotating back to original frame
# this is required since we only use a subset of m components for SO(2) convolution
def get_rotate_inv_rescale(self, lmax, mmax):
temp = self.prepare_rotate_inv_rescale()
return temp[lmax][mmax]


def __repr__(self):
return f"{self.__class__.__name__}(lmax_list={self.lmax_list}, mmax_list={self.mmax_list})"

class SO3_Rotation(torch.nn.Module):
class SO3_Rotation:
"""
Helper functions for Wigner-D rotations
Expand All @@ -262,12 +259,10 @@ def __init__(
self.wigner = self.wigner.detach()
self.wigner_inv = self.wigner_inv.detach()

self.set_lmax(lmax)

# Initialize coefficients for reshape l<-->m
def set_lmax(self, lmax) -> None:
self.lmax = lmax
self.mapping = CoefficientMapping([self.lmax], [self.lmax], use_rotate_inv_rescale=True)
import pdb;pdb.set_trace()
self.mapping = CoefficientMapping([self.lmax], [self.lmax])


# Rotate the embedding
def rotate(self, embedding, out_lmax, out_mmax) -> torch.Tensor:
Expand All @@ -286,7 +281,7 @@ def rotate_inv(self, embedding, in_lmax, in_mmax) -> torch.Tensor:
def RotationToWignerDMatrix(
self, edge_rot_mat: torch.Tensor, start_lmax: int, end_lmax: int
) -> torch.Tensor:
x = edge_rot_mat @ edge_rot_mat.new_tensor([0.0, 1.0, 0.0])
x = edge_rot_mat[:,:,1]
alpha, beta = o3.xyz_to_angles(x)
R = (
o3.angles_to_matrix(alpha, beta, torch.zeros_like(alpha)).transpose(-1, -2)
Expand Down
28 changes: 19 additions & 9 deletions tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from fairchem.core.models.scn.smearing import GaussianSmearing
from fairchem.core.models.base import GraphModelMixin

from fairchem.core.models.utils.so3_utils import CoefficientMapping, SO3_Grid, SO3_Rotation
from fairchem.core.models.utils.so3_utils import CoefficientMapping, SO3_Grid, rotation_to_wigner
from fairchem.core.models.escn import escn_exportable

from torch.export import export
Expand Down Expand Up @@ -175,7 +175,7 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None:
compiled_out = compiled_model(*args)
assert torch.allclose(compiled_out, regular_out, atol=tol)

def test_escn_message_block_exports_and_compiles(self) -> None:
def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
random.seed(1)

sphere_channels = 128
Expand All @@ -186,7 +186,7 @@ def test_escn_message_block_exports_and_compiles(self) -> None:
SO3_grid = torch.nn.ModuleDict()
SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax)
SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax)
mappingReduced = escn_exportable.CoefficientMapping([lmax], [mmax])
mappingReduced = CoefficientMapping([lmax], [mmax])
message_block = escn_exportable.MessageBlock(
layer_idx = 0,
sphere_channels = sphere_channels,
Expand All @@ -208,7 +208,7 @@ def test_escn_message_block_exports_and_compiles(self) -> None:
edge_rot_mat = full_model._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
SO3_edge_rot = SO3_Rotation(edge_rot_mat, lmax)
wigner = rotation_to_wigner(edge_rot_mat, 0, lmax).detach()

# generate inputs
batch_sizes = [34]
Expand All @@ -220,15 +220,25 @@ def test_escn_message_block_exports_and_compiles(self) -> None:
atom_n = torch.randint(1, 90, (b,))
edge_d = torch.rand([num_edges])
edge_indx = torch.randint(0, b, (2, num_edges))
args.append((x, atom_n, edge_d, edge_indx, SO3_edge_rot))
args.append((x, atom_n, edge_d, edge_indx, wigner))

torch._dynamo.config.optimize_ddp = False
torch._dynamo.config.assume_static_by_default = False
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.verbose = True
# torch._logging.set_logs(dynamo = logging.INFO)
# torch._dynamo.reset()
# explain_output = torch._dynamo.explain(message_block)(*args[0])
# print(explain_output)
compiled_model = torch.compile(message_block, dynamic=True)
compiled_output = compiled_model(*args[0])

# compiled_model = torch.compile(message_block, dynamic=True)
exported_prog = export(message_block, args=args[0])
exported_output = exported_prog(*args[0])

# output = message_block(*args)
# compiled_output = compiled_model(*args)

regular_out = message_block(*args[0])
assert torch.allclose(compiled_output, regular_out, atol=tol)
assert torch.allclose(exported_output, regular_out, atol=tol)

def test_escn_compiles(self):
init("gloo")
Expand Down

0 comments on commit f56b440

Please sign in to comment.