diff --git a/src/fairchem/core/models/escn/escn_exportable.py b/src/fairchem/core/models/escn/escn_exportable.py index 97702873a..dfeb4f0b2 100644 --- a/src/fairchem/core/models/escn/escn_exportable.py +++ b/src/fairchem/core/models/escn/escn_exportable.py @@ -27,7 +27,6 @@ SO3_Rotation, rotation_to_wigner ) -from fairchem.core.models.escn.so3 import SO3_Embedding from fairchem.core.models.scn.sampling import CalcSpherePoints from fairchem.core.models.scn.smearing import ( GaussianSmearing, @@ -56,8 +55,8 @@ class eSCN(nn.Module, GraphModelMixin): max_num_elements (int): Maximum atomic number num_layers (int): Number of layers in the GNN - lmax_list (int): List of maximum degree of the spherical harmonics (1 to 10) - mmax_list (int): List of maximum order of the spherical harmonics (0 to lmax) + lmax (int): maximum degree of the spherical harmonics (1 to 10) + mmax (int): maximum order of the spherical harmonics (0 to lmax) sphere_channels (int): Number of spherical channels (one set per resolution) hidden_channels (int): Number of hidden units in message passing num_sphere_samples (int): Number of samples used to approximate the integration of the sphere in the output blocks @@ -78,8 +77,8 @@ def __init__( cutoff: float = 8.0, max_num_elements: int = 90, num_layers: int = 8, - lmax_list: list[int] | None = None, - mmax_list: list[int] | None = None, + lmax_list: List[int] = [4], # list of 1, for backward compat only right now, + mmax_list: List[int] = [2], # list of 1, for backward compat only right now, sphere_channels: int = 128, hidden_channels: int = 256, edge_channels: int = 128, @@ -90,10 +89,6 @@ def __init__( show_timing_info: bool = False, resolution: int | None = None, ) -> None: - if mmax_list is None: - mmax_list = [2] - if lmax_list is None: - lmax_list = [6] super().__init__() import sys @@ -120,8 +115,9 @@ def __init__( self.grad_forces = False self.lmax_list = lmax_list self.mmax_list = mmax_list - self.num_resolutions: int = len(self.lmax_list) - self.sphere_channels_all: int = self.num_resolutions * self.sphere_channels + assert len(self.lmax_list) == 1 and len(self.mmax_list) == 1 + self.lmax = lmax_list[0] + self.mmax = mmax_list[0] self.basis_width_scalar = basis_width_scalar self.distance_function = distance_function @@ -133,7 +129,7 @@ def __init__( # Weights for message initialization self.sphere_embedding = nn.Embedding( - self.max_num_elements, self.sphere_channels_all + self.max_num_elements, self.sphere_channels ) # Initialize the function used to measure the distances between atoms @@ -174,11 +170,10 @@ def __init__( ) # Initialize the transformations between spherical and grid representations - assert self.num_resolutions == 1, "Only one resolution is supported" self.SO3_grid = nn.ModuleDict() - self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax_list[0], self.lmax_list[0], resolution=resolution) - self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax_list[0], self.mmax_list[0], resolution=resolution) - self.mappingReduced = CoefficientMapping(self.lmax_list, self.mmax_list) + self.SO3_grid["lmax_lmax"] = SO3_Grid(self.lmax, self.lmax, resolution=resolution) + self.SO3_grid["lmax_mmax"] = SO3_Grid(self.lmax, self.mmax, resolution=resolution) + self.mappingReduced = CoefficientMapping([self.lmax], [self.mmax]) # Initialize the blocks for each layer of the GNN self.layer_blocks = nn.ModuleList() @@ -188,8 +183,8 @@ def __init__( self.sphere_channels, self.hidden_channels, self.edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.distance_expansion, self.max_num_elements, self.SO3_grid, @@ -200,11 +195,11 @@ def __init__( # Output blocks for energy and forces self.energy_block = EnergyBlock( - self.sphere_channels_all, self.num_sphere_samples, self.act + self.sphere_channels, self.num_sphere_samples, self.act ) if self.regress_forces: self.force_block = ForceBlock( - self.sphere_channels_all, self.num_sphere_samples, self.act + self.sphere_channels, self.num_sphere_samples, self.act ) # Create a roughly evenly distributed point sampling of the sphere for the output blocks @@ -213,19 +208,14 @@ def __init__( ) # For each spherical point, compute the spherical harmonic coefficient weights - sphharm_weights: list[nn.Parameter] = [] - for i in range(self.num_resolutions): - sphharm_weights.append( - nn.Parameter( - o3.spherical_harmonics( - torch.arange(0, self.lmax_list[i] + 1).tolist(), - self.sphere_points, - False, - ), - requires_grad=False, - ) - ) - self.sphharm_weights = nn.ParameterList(sphharm_weights) + self.sphharm_weights: nn.Parameter = nn.Parameter( + o3.spherical_harmonics( + torch.arange(0, self.lmax + 1).tolist(), + self.sphere_points, + False, + ), + requires_grad=False, + ) @conditional_grad(torch.enable_grad()) @@ -250,36 +240,25 @@ def forward(self, data): edge_rot_mat = self._init_edge_rot_mat( data, graph.edge_index, graph.edge_distance_vec ) - wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax_list[0]).detach() + wigner = rotation_to_wigner(edge_rot_mat, 0, self.lmax).detach() ############################################################### # Initialize node embeddings ############################################################### # Init per node representations using an atomic number based embedding - offset = 0 - x = SO3_Embedding( + x_message = torch.zeros( num_atoms, - self.lmax_list, + int((self.lmax + 1) ** 2), self.sphere_channels, - device, - self.dtype, + device=device, + dtype=self.dtype, ) - - offset_res = 0 - offset = 0 - # Initialize the l=0,m=0 coefficients for each resolution - for i in range(self.num_resolutions): - x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ - :, offset : offset + self.sphere_channels - ] - offset = offset + self.sphere_channels - offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + x_message[:, 0, :] = self.sphere_embedding(atomic_numbers) ############################################################### # Update spherical node embeddings ############################################################### - x_message = x.embedding for i in range(self.num_layers): if i > 0: x_message_new = self.layer_blocks[i]( @@ -302,29 +281,10 @@ def forward(self, data): graph.edge_index, wigner, ) - x.embedding = x_message # Sample the spherical channels (node embeddings) at evenly distributed points on the sphere. # These values are fed into the output blocks. - x_pt = torch.tensor([], device=device) - offset = 0 - # Compute the embedding values at every sampled point on the sphere - for i in range(self.num_resolutions): - num_coefficients = int((x.lmax_list[i] + 1) ** 2) - x_pt = torch.cat( - [ - x_pt, - torch.einsum( - "abc, pb->apc", - x.embedding[:, offset : offset + num_coefficients], - self.sphharm_weights[i], - ).contiguous(), - ], - dim=2, - ) - offset = offset + num_coefficients - - x_pt = x_pt.view(-1, self.sphere_channels_all) + x_pt = torch.einsum("abc, pb->apc", x_message, self.sphharm_weights).contiguous() ############################################################### # Energy estimation @@ -423,8 +383,8 @@ class LayerBlock(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int) degrees (l) for each resolution + mmax (int): orders (m) for each resolution distance_expansion (func): Function used to compute distance embedding max_num_elements (int): Maximum number of atomic numbers SO3_grid (SO3_grid): Class used to convert from grid the spherical harmonic representations @@ -437,8 +397,8 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, distance_expansion, max_num_elements: int, SO3_grid: SO3_Grid, @@ -448,11 +408,9 @@ def __init__( super().__init__() self.layer_idx = layer_idx self.act = act - self.lmax_list = lmax_list - self.mmax_list = mmax_list - self.num_resolutions = len(lmax_list) + self.lmax = lmax + self.mmax = mmax self.sphere_channels = sphere_channels - self.sphere_channels_all = self.num_resolutions * self.sphere_channels self.SO3_grid = SO3_grid self.mappingReduced = mappingReduced @@ -462,8 +420,8 @@ def __init__( self.sphere_channels, hidden_channels, edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, distance_expansion, max_num_elements, self.SO3_grid, @@ -473,15 +431,15 @@ def __init__( # Non-linear point-wise comvolution for the aggregated messages self.fc1_sphere = nn.Linear( - 2 * self.sphere_channels_all, self.sphere_channels_all, bias=False + 2 * self.sphere_channels, self.sphere_channels, bias=False ) self.fc2_sphere = nn.Linear( - self.sphere_channels_all, self.sphere_channels_all, bias=False + self.sphere_channels, self.sphere_channels, bias=False ) self.fc3_sphere = nn.Linear( - self.sphere_channels_all, self.sphere_channels_all, bias=False + self.sphere_channels, self.sphere_channels, bias=False ) def forward( @@ -505,7 +463,7 @@ def forward( # Project to grid # x_grid_message = x_message.to_grid(self.SO3_grid["lmax_lmax"]) - to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + to_grid_mat = self.SO3_grid["lmax_lmax"].to_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)] x_grid_message = torch.einsum("bai,zic->zbac", to_grid_mat, x_message) # x_grid = x.to_grid(self.SO3_grid["lmax_lmax"]) @@ -519,7 +477,7 @@ def forward( # Project back to spherical harmonic coefficients # x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"]) - from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])] + from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax, self.lmax)] x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Return aggregated messages @@ -535,8 +493,8 @@ class MessageBlock(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution distance_expansion (func): Function used to compute distance embedding max_num_elements (int): Maximum number of atomic numbers SO3_grid (SO3_grid): Class used to convert from grid the spherical harmonic representations @@ -549,8 +507,8 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, distance_expansion, max_num_elements: int, SO3_grid: SO3_Grid, @@ -563,12 +521,11 @@ def __init__( self.hidden_channels = hidden_channels self.sphere_channels = sphere_channels self.SO3_grid = SO3_grid - self.num_resolutions = len(lmax_list) - self.lmax_list = lmax_list - self.mmax_list = mmax_list + self.lmax = lmax + self.mmax = mmax self.edge_channels = edge_channels self.mappingReduced = mappingReduced - self.out_mask = self.mappingReduced.coefficient_idx(self.lmax_list[0], self.mmax_list[0]) + self.out_mask = self.mappingReduced.coefficient_idx(self.lmax, self.mmax) # Create edge scalar (invariant to rotations) features self.edge_block = EdgeBlock( @@ -583,8 +540,8 @@ def __init__( self.sphere_channels, self.hidden_channels, self.edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.act, self.mappingReduced ) @@ -592,8 +549,8 @@ def __init__( self.sphere_channels, self.hidden_channels, self.edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.act, self.mappingReduced ) @@ -626,8 +583,6 @@ def forward( # Rotate the irreps to align with the edge x_source = torch.bmm(wigner[:, self.out_mask, :], x_source) x_target = torch.bmm(wigner[:, self.out_mask, :], x_target) - # x_source = SO3_edge_rot.rotate(x_source, self.lmax_list[0], self.mmax_list[0]) - # x_target = SO3_edge_rot.rotate(x_target, self.lmax_list[0], self.mmax_list[0]) # Compute messages x_source = self.so2_block_source(x_source, x_edge) @@ -644,14 +599,12 @@ def forward( x_target = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid) # Rotate back the irreps - # x_target = SO3_edge_rot.rotate_inv(x_target, self.lmax_list[0], self.mmax_list[0]) wigner_inv = torch.transpose(wigner, 1, 2).contiguous() x_target = torch.bmm(wigner_inv[:, :, self.out_mask], x_target) # Compute the sum of the incoming neighboring messages for each target node new_embedding = torch.fill(x.clone(), 0) new_embedding.index_add_(0, edge_index[1], x_target) - # x_target._reduce_edge(edge_index[1], len(x.embedding)) return new_embedding @@ -664,8 +617,8 @@ class SO2Block(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution act (function): Non-linear activation function """ @@ -674,24 +627,20 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, act, mappingReduced ) -> None: super().__init__() self.sphere_channels = sphere_channels self.hidden_channels = hidden_channels - self.lmax_list = lmax_list - self.mmax_list = mmax_list - self.num_resolutions: int = len(lmax_list) + self.lmax = lmax + self.mmax = mmax self.act = act self.mappingReduced = mappingReduced - num_channels_m0 = 0 - for i in range(self.num_resolutions): - num_coefficents = self.lmax_list[i] + 1 - num_channels_m0 = num_channels_m0 + num_coefficents * self.sphere_channels + num_channels_m0 = (self.lmax + 1) * self.sphere_channels # SO(2) convolution for m=0 self.fc1_dist0 = nn.Linear(edge_channels, self.hidden_channels) @@ -700,14 +649,14 @@ def __init__( # SO(2) convolution for non-zero m self.so2_conv = nn.ModuleList() - for m in range(1, max(self.mmax_list) + 1): + for m in range(1, self.mmax + 1): so2_conv = SO2Conv( m, self.sphere_channels, self.hidden_channels, edge_channels, - self.lmax_list, - self.mmax_list, + self.lmax, + self.mmax, self.act, ) self.so2_conv.append(so2_conv) @@ -740,7 +689,7 @@ def forward( # Compute the values for the m > 0 coefficients offset = self.mappingReduced.m_size[0] - for m in range(1, max(self.mmax_list) + 1): + for m in range(1, self.mmax + 1): # Get the m order coefficients x_m = x[ :, offset : offset + 2 * self.mappingReduced.m_size[m] @@ -769,8 +718,8 @@ class SO2Conv(torch.nn.Module): sphere_channels (int): Number of spherical channels hidden_channels (int): Number of hidden channels used during the SO(2) conv edge_channels (int): Size of invariant edge embedding - lmax_list (list:int): List of degrees (l) for each resolution - mmax_list (list:int): List of orders (m) for each resolution + lmax (int): degrees (l) for each resolution + mmax (int): orders (m) for each resolution act (function): Non-linear activation function """ @@ -780,26 +729,23 @@ def __init__( sphere_channels: int, hidden_channels: int, edge_channels: int, - lmax_list: list[int], - mmax_list: list[int], + lmax: int, + mmax: int, act, ) -> None: super().__init__() self.hidden_channels = hidden_channels - self.lmax_list = lmax_list - self.mmax_list = mmax_list + self.lmax = lmax + self.mmax = mmax self.sphere_channels = sphere_channels - self.num_resolutions: int = len(self.lmax_list) self.m = m self.act = act - num_channels = 0 - for i in range(self.num_resolutions): - num_coefficents = 0 - if self.mmax_list[i] >= m: - num_coefficents = self.lmax_list[i] - m + 1 + num_coefficents = 0 + if self.mmax >= m: + num_coefficents = self.lmax - m + 1 - num_channels = num_channels + num_coefficents * self.sphere_channels + num_channels = num_coefficents * self.sphere_channels assert num_channels > 0 diff --git a/tests/core/models/test_escn_compiles.py b/tests/core/models/test_escn_compiles.py index 8e53eb702..e653ef405 100644 --- a/tests/core/models/test_escn_compiles.py +++ b/tests/core/models/test_escn_compiles.py @@ -161,8 +161,8 @@ def test_escn_so2_conv_exports_and_compiles(self, tol=1e-5) -> None: sphere_channels=shpere_channels, hidden_channels=128, edge_channels=edge_channels, - lmax_list=[lmax], - mmax_list=[mmax], + lmax=lmax, + mmax=mmax, act=torch.nn.SiLU(), mappingReduced=mappingReduced ) @@ -192,8 +192,8 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None: sphere_channels = sphere_channels, hidden_channels = hidden_channels, edge_channels = edge_channels, - lmax_list = [lmax], - mmax_list = [mmax], + lmax = lmax, + mmax = mmax, distance_expansion = distance_expansion, max_num_elements = 90, SO3_grid = SO3_grid, @@ -253,8 +253,8 @@ def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None: sphere_channels = sphere_channels, hidden_channels = hidden_channels, edge_channels = edge_channels, - lmax_list = [lmax], - mmax_list = [mmax], + lmax = lmax, + mmax = mmax, distance_expansion = distance_expansion, max_num_elements = 90, SO3_grid = SO3_grid, @@ -320,7 +320,6 @@ def test_escn_compiles(self): # os.environ["TORCHDYNAMO_REPRO_AFTER"]="dynamo" # torch._dynamo.config.verbose = True compiled_model = torch.compile(model, dynamic=True) - torch._dynamo.config.optimize_ddp = False # torch._dynamo.explain(model)(data) # assert False # torch._dynamo.reset()