Skip to content

Commit

Permalink
Add in tests for coefficient mapping, mprimary, lprimary (#626)
Browse files Browse the repository at this point in the history
  • Loading branch information
misko authored Mar 8, 2024
1 parent 595995a commit 9108a87
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 0 deletions.
75 changes: 75 additions & 0 deletions tests/models/test_equiformer_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import load_state_dict, setup_imports
from ocpmodels.datasets import data_list_collater
from ocpmodels.models.equiformer_v2.so3 import (
CoefficientMappingModule,
SO3_Embedding,
)
from ocpmodels.preprocessing import AtomsToGraphs


Expand Down Expand Up @@ -117,3 +121,74 @@ def test_energy_force_shape(self, snapshot):

assert snapshot == forces.shape
assert snapshot == pytest.approx(forces.detach().mean(0))


class TestMPrimaryLPrimary:
def test_mprimary_lprimary_mappings(self):
def sign(x):
return 1 if x >= 0 else -1

device = torch.device("cpu")
lmax_list = [6, 8]
mmax_list = [3, 6]
for lmax in lmax_list:
for mmax in mmax_list:
c = CoefficientMappingModule([lmax], [mmax])

embedding = SO3_Embedding(
length=1,
lmax_list=[lmax],
num_channels=1,
device=device,
dtype=torch.float32,
)

"""
Generate L_primary matrix
L0: 0.00 ~ L0M0
L1: -1.01 1.00 1.01 ~ L1M(-1),L1M0,L1M1
L2: -2.02 -2.01 2.00 2.01 2.02 ~ L2M(-2),L2M(-1),L2M0,L2M1,L2M2
"""
test_matrix_lp = []
for l in range(lmax + 1):
max_m = min(l, mmax)
for m in range(-max_m, max_m + 1):
v = l * sign(m) + 0.01 * m # +/- l . 00 m
test_matrix_lp.append(v)

test_matrix_lp = (
torch.tensor(test_matrix_lp)
.reshape(1, -1, 1)
.to(torch.float32)
)

"""
Generate M_primary matrix
M0: 0.00 , 1.00, 2.00, ... , LMax ~ M0L0, M0L1, .., M0L(LMax)
M1: 1.01, 2.01, .., LMax.01, -1.01, -2.01, -LMax.01 ~ L1M1, L2M1, .., L(LMax)M1, L1M(-1), L2M(-1), ... , L(LMax)M(-1)
"""
test_matrix_mp = []
for m in range(max_m + 1):
for l in range(m, lmax + 1):
v = l + 0.01 * m # +/- l . 00 m
test_matrix_mp.append(v)
if m > 0:
for l in range(m, lmax + 1):
v = -(l + 0.01 * m) # +/- l . 00 m
test_matrix_mp.append(v)

test_matrix_mp = (
torch.tensor(test_matrix_mp)
.reshape(1, -1, 1)
.to(torch.float32)
)

embedding.embedding = test_matrix_lp.clone()

embedding._m_primary(c)
mp = embedding.embedding.clone()
(test_matrix_mp == mp).all()

embedding._l_primary(c)
lp = embedding.embedding.clone()
(test_matrix_lp == lp).all()
75 changes: 75 additions & 0 deletions tests/models/test_escn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch

from ocpmodels.models.escn.so3 import CoefficientMapping
from ocpmodels.models.escn.so3 import SO3_Embedding as escn_SO3_Embedding


class TestMPrimaryLPrimary:
def test_mprimary_lprimary_mappings(self):
def sign(x):
return 1 if x >= 0 else -1

device = torch.device("cpu")
lmax_list = [6, 8]
mmax_list = [3, 6]
for lmax in lmax_list:
for mmax in mmax_list:
c = CoefficientMapping([lmax], [mmax], device=device)

escn_embedding = escn_SO3_Embedding(
length=1,
lmax_list=[lmax],
num_channels=1,
device=device,
dtype=torch.float32,
)

"""
Generate L_primary matrix
L0: 0.00 ~ L0M0
L1: -1.01 1.00 1.01 ~ L1M(-1),L1M0,L1M1
L2: -2.02 -2.01 2.00 2.01 2.02 ~ L2M(-2),L2M(-1),L2M0,L2M1,L2M2
"""
test_matrix_lp = []
for l in range(lmax + 1):
max_m = min(l, mmax)
for m in range(-max_m, max_m + 1):
v = l * sign(m) + 0.01 * m # +/- l . 00 m
test_matrix_lp.append(v)

test_matrix_lp = (
torch.tensor(test_matrix_lp)
.reshape(1, -1, 1)
.to(torch.float32)
)

"""
Generate M_primary matrix
M0: 0.00 , 1.00, 2.00, ... , LMax ~ M0L0, M0L1, .., M0L(LMax)
M1: 1.01, 2.01, .., LMax.01, -1.01, -2.01, -LMax.01 ~ L1M1, L2M1, .., L(LMax)M1, L1M(-1), L2M(-1), ... , L(LMax)M(-1)
"""
test_matrix_mp = []
for m in range(max_m + 1):
for l in range(m, lmax + 1):
v = l + 0.01 * m # +/- l . 00 m
test_matrix_mp.append(v)
if m > 0:
for l in range(m, lmax + 1):
v = -(l + 0.01 * m) # +/- l . 00 m
test_matrix_mp.append(v)

test_matrix_mp = (
torch.tensor(test_matrix_mp)
.reshape(1, -1, 1)
.to(torch.float32)
)

escn_embedding.embedding = test_matrix_lp.clone()

escn_embedding._m_primary(c)
mp = escn_embedding.embedding.clone()
(test_matrix_mp == mp).all()

escn_embedding._l_primary(c)
lp = escn_embedding.embedding.clone()
(test_matrix_lp == lp).all()

0 comments on commit 9108a87

Please sign in to comment.