From f409ac7195baac092fb7bf0777b2b51f508967d9 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 25 Oct 2024 17:07:44 +0000 Subject: [PATCH 1/3] add option for new edge rot mat --- src/fairchem/core/models/escn/edge_rot_mat.py | 123 ++++++++++++++++++ src/fairchem/core/models/escn/escn.py | 68 ++-------- 2 files changed, 134 insertions(+), 57 deletions(-) create mode 100644 src/fairchem/core/models/escn/edge_rot_mat.py diff --git a/src/fairchem/core/models/escn/edge_rot_mat.py b/src/fairchem/core/models/escn/edge_rot_mat.py new file mode 100644 index 000000000..2c4b4c6c4 --- /dev/null +++ b/src/fairchem/core/models/escn/edge_rot_mat.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +import logging +import math + +import torch + + +# Algorithm from Ken Whatmough (https://math.stackexchange.com/users/918128/ken-whatmough) +def vec3_to_perp_vec3(v): + """ + Small proof: + input = x y z + output = s(x)|z| s(y)|z| -s(z)(|x|+|y|) + + input dot output + = x*s(x)*|z| + y*s(y)*|z| - z*s(z)*|x| - z*s(z)*|y| + a*s(a)=|a| , + = |x|*|z| + |y|*|z| - |z|*|x| - |z|*|y| = 0 + + """ + return torch.hstack( + [ + v[:, [2]].copysign(v[:, [0, 1]]), + -v[:, [0, 1]].copysign(v[:, [2]]).sum(axis=1, keepdim=True), + ] + ) + + +# https://en.wikipedia.org/wiki/Rodrigues'_rotation_formula#Matrix_notation +def vec3_rotate_around_axis(v, axis, thetas): + # v_rot= v + (sTheta)*(axis X v) + (1-cTheta)*(axis X (axis X v)) + Kv = torch.cross(axis, v, dim=1) + KKv = torch.cross(axis, Kv, dim=1) + s_theta = torch.sin(thetas) + c_theta = torch.cos(thetas) + return v + s_theta * Kv + (1 - c_theta) * KKv + + +def init_edge_rot_mat_new(data, edge_index, edge_distance_vec): + edge_vec_0 = edge_distance_vec.detach() + edge_vec_0_distance = torch.linalg.norm(edge_vec_0, axis=1, keepdim=True) + + # Make sure the atoms are far enough apart + # assert torch.min(edge_vec_0_distance) < 0.0001 + if torch.min(edge_vec_0_distance) < 0.0001: + logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}") + + norm_x = edge_vec_0 / edge_vec_0_distance + + perp_to_norm_x = vec3_to_perp_vec3(norm_x) + random_rotated_in_plane_perp_to_norm_x = vec3_rotate_around_axis( + perp_to_norm_x, + norm_x, + torch.rand((norm_x.shape[0], 1), device=norm_x.device) * 2 * math.pi, + ) + + norm_z = random_rotated_in_plane_perp_to_norm_x / torch.linalg.norm( + random_rotated_in_plane_perp_to_norm_x, axis=1, keepdim=True + ) + + norm_y = torch.cross(norm_x, norm_z, dim=1) + norm_y /= torch.linalg.norm(norm_y, dim=1, keepdim=True) + + # Construct the 3D rotation matrix + norm_x = norm_x.view(-1, 1, 3) + norm_y = -norm_y.view(-1, 1, 3) + norm_z = norm_z.view(-1, 1, 3) + return torch.cat([norm_z, norm_x, norm_y], dim=1).contiguous() + + +# Initialize the edge rotation matrics +def init_edge_rot_mat_og(data, edge_index, edge_distance_vec): + edge_vec_0 = edge_distance_vec + edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) + + # Make sure the atoms are far enough apart + if torch.min(edge_vec_0_distance) < 0.0001: + logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}") + (minval, minidx) = torch.min(edge_vec_0_distance, 0) + logging.error( + f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}" + ) + + norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) + + edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 + edge_vec_2 = edge_vec_2 / (torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1)) + # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x + # With two 90 degree rotated vectors, at least one should not be aligned with norm_x + edge_vec_2b = edge_vec_2.clone() + edge_vec_2b[:, 0] = -edge_vec_2[:, 1] + edge_vec_2b[:, 1] = edge_vec_2[:, 0] + edge_vec_2c = edge_vec_2.clone() + edge_vec_2c[:, 1] = -edge_vec_2[:, 2] + edge_vec_2c[:, 2] = edge_vec_2[:, 1] + vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) + vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) + + vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) + edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2) + vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) + edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2) + + vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) + # Check the vectors aren't aligned + assert torch.max(vec_dot) < 0.99 + + norm_z = torch.cross(norm_x, edge_vec_2, dim=1) + norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True))) + norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1)) + norm_y = torch.cross(norm_x, norm_z, dim=1) + norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) + + # Construct the 3D rotation matrix + norm_x = norm_x.view(-1, 3, 1) + norm_y = -norm_y.view(-1, 3, 1) + norm_z = norm_z.view(-1, 3, 1) + + edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) + edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) + + return edge_rot_mat.detach() diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index c17b8bda7..350e099f4 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -15,6 +15,11 @@ import torch import torch.nn as nn +from fairchem.core.models.escn.edge_rot_mat import ( + init_edge_rot_mat_new, + init_edge_rot_mat_og, +) + if typing.TYPE_CHECKING: from torch_geometric.data.batch import Batch @@ -89,6 +94,7 @@ def __init__( show_timing_info: bool = False, resolution: int | None = None, activation_checkpoint: bool | None = False, + edge_rot_mat: str = "og", ) -> None: if mmax_list is None: mmax_list = [2] @@ -126,6 +132,11 @@ def __init__( self.basis_width_scalar = basis_width_scalar self.distance_function = distance_function + if edge_rot_mat == "og": + self._init_edge_rot_mat = init_edge_rot_mat_og + else: + self._init_edge_rot_mat = init_edge_rot_mat_new + # variables used for display purposes self.counter = 0 @@ -365,63 +376,6 @@ def forward(self, data): return outputs - # Initialize the edge rotation matrics - def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): - edge_vec_0 = edge_distance_vec - edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) - - # Make sure the atoms are far enough apart - if torch.min(edge_vec_0_distance) < 0.0001: - logging.error( - f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}" - ) - (minval, minidx) = torch.min(edge_vec_0_distance, 0) - logging.error( - f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}" - ) - - norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) - - edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 - edge_vec_2 = edge_vec_2 / ( - torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1) - ) - # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x - # With two 90 degree rotated vectors, at least one should not be aligned with norm_x - edge_vec_2b = edge_vec_2.clone() - edge_vec_2b[:, 0] = -edge_vec_2[:, 1] - edge_vec_2b[:, 1] = edge_vec_2[:, 0] - edge_vec_2c = edge_vec_2.clone() - edge_vec_2c[:, 1] = -edge_vec_2[:, 2] - edge_vec_2c[:, 2] = edge_vec_2[:, 1] - vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) - vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2) - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) - # Check the vectors aren't aligned - assert torch.max(vec_dot) < 0.99 - - norm_z = torch.cross(norm_x, edge_vec_2, dim=1) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True))) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1)) - norm_y = torch.cross(norm_x, norm_z, dim=1) - norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) - - # Construct the 3D rotation matrix - norm_x = norm_x.view(-1, 3, 1) - norm_y = -norm_y.view(-1, 3, 1) - norm_z = norm_z.view(-1, 3, 1) - - edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) - edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) - - return edge_rot_mat.detach() - @property def num_params(self) -> int: return sum(p.numel() for p in self.parameters()) From ad2a06c7a8fb91ac02bd114232e72f74f9f19b77 Mon Sep 17 00:00:00 2001 From: Misko Date: Fri, 25 Oct 2024 17:15:47 +0000 Subject: [PATCH 2/3] update equiformer v2 as well --- .../core/models/equiformer_v2/edge_rot_mat.py | 55 ------------------ .../models/equiformer_v2/equiformer_v2.py | 10 +--- src/fairchem/core/models/escn/edge_rot_mat.py | 57 +------------------ src/fairchem/core/models/escn/escn.py | 14 +---- 4 files changed, 6 insertions(+), 130 deletions(-) delete mode 100644 src/fairchem/core/models/equiformer_v2/edge_rot_mat.py diff --git a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py b/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py deleted file mode 100644 index c83cc3143..000000000 --- a/src/fairchem/core/models/equiformer_v2/edge_rot_mat.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -import logging - -import torch - - -def init_edge_rot_mat(edge_distance_vec): - edge_vec_0 = edge_distance_vec - edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) - - # Make sure the atoms are far enough apart - # assert torch.min(edge_vec_0_distance) < 0.0001 - if torch.min(edge_vec_0_distance) < 0.0001: - logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}") - - norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) - - edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 - edge_vec_2 = edge_vec_2 / (torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1)) - # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x - # With two 90 degree rotated vectors, at least one should not be aligned with norm_x - edge_vec_2b = edge_vec_2.clone() - edge_vec_2b[:, 0] = -edge_vec_2[:, 1] - edge_vec_2b[:, 1] = edge_vec_2[:, 0] - edge_vec_2c = edge_vec_2.clone() - edge_vec_2c[:, 1] = -edge_vec_2[:, 2] - edge_vec_2c[:, 2] = edge_vec_2[:, 1] - vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) - vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2) - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) - # Check the vectors aren't aligned - assert torch.max(vec_dot) < 0.99 - - norm_z = torch.cross(norm_x, edge_vec_2, dim=1) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True))) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1)) - norm_y = torch.cross(norm_x, norm_z, dim=1) - norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) - - # Construct the 3D rotation matrix - norm_x = norm_x.view(-1, 3, 1) - norm_y = -norm_y.view(-1, 3, 1) - norm_z = norm_z.view(-1, 3, 1) - - edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) - edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) - - return edge_rot_mat.detach() diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py index 98e21a77f..d96b1ca9a 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py @@ -15,6 +15,7 @@ GraphModelMixin, HeadInterface, ) +from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): @@ -23,7 +24,6 @@ import typing -from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding from .layer_norm import ( @@ -443,9 +443,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -569,10 +567,6 @@ def _init_gp_partitions( edge_distance_vec, ) - # Initialize the edge rotation matrics - def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): - return init_edge_rot_mat(edge_distance_vec) - @property def num_params(self): return sum(p.numel() for p in self.parameters()) diff --git a/src/fairchem/core/models/escn/edge_rot_mat.py b/src/fairchem/core/models/escn/edge_rot_mat.py index 2c4b4c6c4..72c960a4a 100644 --- a/src/fairchem/core/models/escn/edge_rot_mat.py +++ b/src/fairchem/core/models/escn/edge_rot_mat.py @@ -37,10 +37,9 @@ def vec3_rotate_around_axis(v, axis, thetas): return v + s_theta * Kv + (1 - c_theta) * KKv -def init_edge_rot_mat_new(data, edge_index, edge_distance_vec): +def init_edge_rot_mat(edge_distance_vec): edge_vec_0 = edge_distance_vec.detach() edge_vec_0_distance = torch.linalg.norm(edge_vec_0, axis=1, keepdim=True) - # Make sure the atoms are far enough apart # assert torch.min(edge_vec_0_distance) < 0.0001 if torch.min(edge_vec_0_distance) < 0.0001: @@ -67,57 +66,3 @@ def init_edge_rot_mat_new(data, edge_index, edge_distance_vec): norm_y = -norm_y.view(-1, 1, 3) norm_z = norm_z.view(-1, 1, 3) return torch.cat([norm_z, norm_x, norm_y], dim=1).contiguous() - - -# Initialize the edge rotation matrics -def init_edge_rot_mat_og(data, edge_index, edge_distance_vec): - edge_vec_0 = edge_distance_vec - edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1)) - - # Make sure the atoms are far enough apart - if torch.min(edge_vec_0_distance) < 0.0001: - logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}") - (minval, minidx) = torch.min(edge_vec_0_distance, 0) - logging.error( - f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}" - ) - - norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1)) - - edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5 - edge_vec_2 = edge_vec_2 / (torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1)) - # Create two rotated copys of the random vectors in case the random vector is aligned with norm_x - # With two 90 degree rotated vectors, at least one should not be aligned with norm_x - edge_vec_2b = edge_vec_2.clone() - edge_vec_2b[:, 0] = -edge_vec_2[:, 1] - edge_vec_2b[:, 1] = edge_vec_2[:, 0] - edge_vec_2c = edge_vec_2.clone() - edge_vec_2c[:, 1] = -edge_vec_2[:, 2] - edge_vec_2c[:, 2] = edge_vec_2[:, 1] - vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1) - vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2) - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1) - edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2) - - vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)) - # Check the vectors aren't aligned - assert torch.max(vec_dot) < 0.99 - - norm_z = torch.cross(norm_x, edge_vec_2, dim=1) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True))) - norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1)) - norm_y = torch.cross(norm_x, norm_z, dim=1) - norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True))) - - # Construct the 3D rotation matrix - norm_x = norm_x.view(-1, 3, 1) - norm_y = -norm_y.view(-1, 3, 1) - norm_z = norm_z.view(-1, 3, 1) - - edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2) - edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2) - - return edge_rot_mat.detach() diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index 350e099f4..ef734b0e3 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -16,8 +16,7 @@ import torch.nn as nn from fairchem.core.models.escn.edge_rot_mat import ( - init_edge_rot_mat_new, - init_edge_rot_mat_og, + init_edge_rot_mat, ) if typing.TYPE_CHECKING: @@ -132,11 +131,6 @@ def __init__( self.basis_width_scalar = basis_width_scalar self.distance_function = distance_function - if edge_rot_mat == "og": - self._init_edge_rot_mat = init_edge_rot_mat_og - else: - self._init_edge_rot_mat = init_edge_rot_mat_new - # variables used for display purposes self.counter = 0 @@ -259,9 +253,7 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() @@ -399,7 +391,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( + edge_rot_mat = init_edge_rot_mat( data, graph.edge_index, graph.edge_distance_vec ) From 7658806fa6afdd1d290e7c1f12862309296bb39d Mon Sep 17 00:00:00 2001 From: Misko Date: Mon, 28 Oct 2024 00:48:41 +0000 Subject: [PATCH 3/3] update syrupy --- .../models/equiformer_v2/equiformer_v2_deprecated.py | 10 ++-------- src/fairchem/core/models/escn/escn.py | 4 +--- .../__snapshots__/test_equiformer_v2_deprecated.ambr | 12 ++++++------ 3 files changed, 9 insertions(+), 17 deletions(-) diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py index 1da2ed3ad..40a5f1d5d 100644 --- a/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2_deprecated.py @@ -11,13 +11,13 @@ from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad from fairchem.core.models.base import GraphModelMixin +from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat from fairchem.core.models.scn.smearing import GaussianSmearing with contextlib.suppress(ImportError): pass -from .edge_rot_mat import init_edge_rot_mat from .gaussian_rbf import GaussianRadialBasisLayer from .input_block import EdgeDegreeEmbedding from .layer_norm import ( @@ -484,9 +484,7 @@ def forward(self, data): ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = self._init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations for i in range(self.num_resolutions): @@ -618,10 +616,6 @@ def forward(self, data): return outputs - # Initialize the edge rotation matrics - def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec): - return init_edge_rot_mat(edge_distance_vec) - @property def num_params(self): return sum(p.numel() for p in self.parameters()) diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py index ef734b0e3..9aeb31ac6 100644 --- a/src/fairchem/core/models/escn/escn.py +++ b/src/fairchem/core/models/escn/escn.py @@ -391,9 +391,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]: ############################################################### # Compute 3x3 rotation matrix per edge - edge_rot_mat = init_edge_rot_mat( - data, graph.edge_index, graph.edge_distance_vec - ) + edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec) # Initialize the WignerD matrices and other values for spherical harmonic calculations self.SO3_edge_rot = nn.ModuleList() diff --git a/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr b/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr index d374d616e..352a5859d 100644 --- a/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr +++ b/tests/core/models/__snapshots__/test_equiformer_v2_deprecated.ambr @@ -6,7 +6,7 @@ # --- # name: TestEquiformerV2.test_ddp.1 Approx( - array([0.12408739], dtype=float32), + array([-0.00897979], dtype=float32), rtol=0.001, atol=0.001 ) @@ -19,7 +19,7 @@ # --- # name: TestEquiformerV2.test_ddp.3 Approx( - array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32), + array([-0.00893596, -0.00290774, -0.02622147], dtype=float32), rtol=0.001, atol=0.001 ) @@ -31,7 +31,7 @@ # --- # name: TestEquiformerV2.test_energy_force_shape.1 Approx( - array([0.12408739], dtype=float32), + array([-0.00897979], dtype=float32), rtol=0.001, atol=0.001 ) @@ -44,7 +44,7 @@ # --- # name: TestEquiformerV2.test_energy_force_shape.3 Approx( - array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32), + array([-0.00893596, -0.00290774, -0.02622147], dtype=float32), rtol=0.001, atol=0.001 ) @@ -56,7 +56,7 @@ # --- # name: TestEquiformerV2.test_gp.1 Approx( - array([0.12408739], dtype=float32), + array([-0.02495257], dtype=float32), rtol=0.001, atol=0.001 ) @@ -69,7 +69,7 @@ # --- # name: TestEquiformerV2.test_gp.3 Approx( - array([ 1.4928661e-03, -7.4134863e-05, 2.9909245e-03], dtype=float32), + array([ 0.00203055, -0.00042872, -0.00279118], dtype=float32), rtol=0.001, atol=0.001 )