Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix edge_rot_mat #895

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 0 additions & 55 deletions src/fairchem/core/models/equiformer_v2/edge_rot_mat.py

This file was deleted.

10 changes: 2 additions & 8 deletions src/fairchem/core/models/equiformer_v2/equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
GraphModelMixin,
HeadInterface,
)
from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat
from fairchem.core.models.scn.smearing import GaussianSmearing

with contextlib.suppress(ImportError):
Expand All @@ -23,7 +24,6 @@

import typing

from .edge_rot_mat import init_edge_rot_mat
from .gaussian_rbf import GaussianRadialBasisLayer
from .input_block import EdgeDegreeEmbedding
from .layer_norm import (
Expand Down Expand Up @@ -443,9 +443,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
for i in range(self.num_resolutions):
Expand Down Expand Up @@ -569,10 +567,6 @@ def _init_gp_partitions(
edge_distance_vec,
)

# Initialize the edge rotation matrics
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
return init_edge_rot_mat(edge_distance_vec)

@property
def num_params(self):
return sum(p.numel() for p in self.parameters())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
from fairchem.core.common.registry import registry
from fairchem.core.common.utils import conditional_grad
from fairchem.core.models.base import GraphModelMixin
from fairchem.core.models.escn.edge_rot_mat import init_edge_rot_mat
from fairchem.core.models.scn.smearing import GaussianSmearing

with contextlib.suppress(ImportError):
pass


from .edge_rot_mat import init_edge_rot_mat
from .gaussian_rbf import GaussianRadialBasisLayer
from .input_block import EdgeDegreeEmbedding
from .layer_norm import (
Expand Down Expand Up @@ -484,9 +484,7 @@ def forward(self, data):
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
for i in range(self.num_resolutions):
Expand Down Expand Up @@ -618,10 +616,6 @@ def forward(self, data):

return outputs

# Initialize the edge rotation matrics
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
return init_edge_rot_mat(edge_distance_vec)

@property
def num_params(self):
return sum(p.numel() for p in self.parameters())
Expand Down
68 changes: 68 additions & 0 deletions src/fairchem/core/models/escn/edge_rot_mat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

import logging
import math

import torch


# Algorithm from Ken Whatmough (https://math.stackexchange.com/users/918128/ken-whatmough)
def vec3_to_perp_vec3(v):
"""
Small proof:
input = x y z
output = s(x)|z| s(y)|z| -s(z)(|x|+|y|)

input dot output
= x*s(x)*|z| + y*s(y)*|z| - z*s(z)*|x| - z*s(z)*|y|
a*s(a)=|a| ,
= |x|*|z| + |y|*|z| - |z|*|x| - |z|*|y| = 0

"""
return torch.hstack(
[
v[:, [2]].copysign(v[:, [0, 1]]),
-v[:, [0, 1]].copysign(v[:, [2]]).sum(axis=1, keepdim=True),
]
)


# https://en.wikipedia.org/wiki/Rodrigues'_rotation_formula#Matrix_notation
def vec3_rotate_around_axis(v, axis, thetas):
# v_rot= v + (sTheta)*(axis X v) + (1-cTheta)*(axis X (axis X v))
Kv = torch.cross(axis, v, dim=1)
KKv = torch.cross(axis, Kv, dim=1)
s_theta = torch.sin(thetas)
c_theta = torch.cos(thetas)
return v + s_theta * Kv + (1 - c_theta) * KKv


def init_edge_rot_mat(edge_distance_vec):
edge_vec_0 = edge_distance_vec.detach()
edge_vec_0_distance = torch.linalg.norm(edge_vec_0, axis=1, keepdim=True)
# Make sure the atoms are far enough apart
# assert torch.min(edge_vec_0_distance) < 0.0001
if torch.min(edge_vec_0_distance) < 0.0001:
logging.error(f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}")

norm_x = edge_vec_0 / edge_vec_0_distance

perp_to_norm_x = vec3_to_perp_vec3(norm_x)
random_rotated_in_plane_perp_to_norm_x = vec3_rotate_around_axis(
perp_to_norm_x,
norm_x,
torch.rand((norm_x.shape[0], 1), device=norm_x.device) * 2 * math.pi,
)

norm_z = random_rotated_in_plane_perp_to_norm_x / torch.linalg.norm(
random_rotated_in_plane_perp_to_norm_x, axis=1, keepdim=True
)

norm_y = torch.cross(norm_x, norm_z, dim=1)
norm_y /= torch.linalg.norm(norm_y, dim=1, keepdim=True)

# Construct the 3D rotation matrix
norm_x = norm_x.view(-1, 1, 3)
norm_y = -norm_y.view(-1, 1, 3)
norm_z = norm_z.view(-1, 1, 3)
return torch.cat([norm_z, norm_x, norm_y], dim=1).contiguous()
70 changes: 7 additions & 63 deletions src/fairchem/core/models/escn/escn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
import torch
import torch.nn as nn

from fairchem.core.models.escn.edge_rot_mat import (
init_edge_rot_mat,
)

if typing.TYPE_CHECKING:
from torch_geometric.data.batch import Batch

Expand Down Expand Up @@ -89,6 +93,7 @@ def __init__(
show_timing_info: bool = False,
resolution: int | None = None,
activation_checkpoint: bool | None = False,
edge_rot_mat: str = "og",
) -> None:
if mmax_list is None:
mmax_list = [2]
Expand Down Expand Up @@ -248,9 +253,7 @@ def forward(self, data):
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = nn.ModuleList()
Expand Down Expand Up @@ -365,63 +368,6 @@ def forward(self, data):

return outputs

# Initialize the edge rotation matrics
def _init_edge_rot_mat(self, data, edge_index, edge_distance_vec):
edge_vec_0 = edge_distance_vec
edge_vec_0_distance = torch.sqrt(torch.sum(edge_vec_0**2, dim=1))

# Make sure the atoms are far enough apart
if torch.min(edge_vec_0_distance) < 0.0001:
logging.error(
f"Error edge_vec_0_distance: {torch.min(edge_vec_0_distance)}"
)
(minval, minidx) = torch.min(edge_vec_0_distance, 0)
logging.error(
f"Error edge_vec_0_distance: {minidx} {edge_index[0, minidx]} {edge_index[1, minidx]} {data.pos[edge_index[0, minidx]]} {data.pos[edge_index[1, minidx]]}"
)

norm_x = edge_vec_0 / (edge_vec_0_distance.view(-1, 1))

edge_vec_2 = torch.rand_like(edge_vec_0) - 0.5
edge_vec_2 = edge_vec_2 / (
torch.sqrt(torch.sum(edge_vec_2**2, dim=1)).view(-1, 1)
)
# Create two rotated copys of the random vectors in case the random vector is aligned with norm_x
# With two 90 degree rotated vectors, at least one should not be aligned with norm_x
edge_vec_2b = edge_vec_2.clone()
edge_vec_2b[:, 0] = -edge_vec_2[:, 1]
edge_vec_2b[:, 1] = edge_vec_2[:, 0]
edge_vec_2c = edge_vec_2.clone()
edge_vec_2c[:, 1] = -edge_vec_2[:, 2]
edge_vec_2c[:, 2] = edge_vec_2[:, 1]
vec_dot_b = torch.abs(torch.sum(edge_vec_2b * norm_x, dim=1)).view(-1, 1)
vec_dot_c = torch.abs(torch.sum(edge_vec_2c * norm_x, dim=1)).view(-1, 1)

vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_b), edge_vec_2b, edge_vec_2)
vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1)).view(-1, 1)
edge_vec_2 = torch.where(torch.gt(vec_dot, vec_dot_c), edge_vec_2c, edge_vec_2)

vec_dot = torch.abs(torch.sum(edge_vec_2 * norm_x, dim=1))
# Check the vectors aren't aligned
assert torch.max(vec_dot) < 0.99

norm_z = torch.cross(norm_x, edge_vec_2, dim=1)
norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1, keepdim=True)))
norm_z = norm_z / (torch.sqrt(torch.sum(norm_z**2, dim=1)).view(-1, 1))
norm_y = torch.cross(norm_x, norm_z, dim=1)
norm_y = norm_y / (torch.sqrt(torch.sum(norm_y**2, dim=1, keepdim=True)))

# Construct the 3D rotation matrix
norm_x = norm_x.view(-1, 3, 1)
norm_y = -norm_y.view(-1, 3, 1)
norm_z = norm_z.view(-1, 3, 1)

edge_rot_mat_inv = torch.cat([norm_z, norm_x, norm_y], dim=2)
edge_rot_mat = torch.transpose(edge_rot_mat_inv, 1, 2)

return edge_rot_mat.detach()

@property
def num_params(self) -> int:
return sum(p.numel() for p in self.parameters())
Expand All @@ -445,9 +391,7 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
###############################################################

# Compute 3x3 rotation matrix per edge
edge_rot_mat = self._init_edge_rot_mat(
data, graph.edge_index, graph.edge_distance_vec
)
edge_rot_mat = init_edge_rot_mat(graph.edge_distance_vec)

# Initialize the WignerD matrices and other values for spherical harmonic calculations
self.SO3_edge_rot = nn.ModuleList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# ---
# name: TestEquiformerV2.test_ddp.1
Approx(
array([0.12408739], dtype=float32),
array([-0.00897979], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -19,7 +19,7 @@
# ---
# name: TestEquiformerV2.test_ddp.3
Approx(
array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32),
array([-0.00893596, -0.00290774, -0.02622147], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -31,7 +31,7 @@
# ---
# name: TestEquiformerV2.test_energy_force_shape.1
Approx(
array([0.12408739], dtype=float32),
array([-0.00897979], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -44,7 +44,7 @@
# ---
# name: TestEquiformerV2.test_energy_force_shape.3
Approx(
array([ 1.4928584e-03, -7.4167408e-05, 2.9909366e-03], dtype=float32),
array([-0.00893596, -0.00290774, -0.02622147], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -56,7 +56,7 @@
# ---
# name: TestEquiformerV2.test_gp.1
Approx(
array([0.12408739], dtype=float32),
array([-0.02495257], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand All @@ -69,7 +69,7 @@
# ---
# name: TestEquiformerV2.test_gp.3
Approx(
array([ 1.4928661e-03, -7.4134863e-05, 2.9909245e-03], dtype=float32),
array([ 0.00203055, -0.00042872, -0.00279118], dtype=float32),
rtol=0.001,
atol=0.001
)
Expand Down