Skip to content

Commit

Permalink
add seed option for calculator (#624)
Browse files Browse the repository at this point in the history
* add seed option for calculator

* add test setting seed in escn

* break out tests to provide terminal output so circleCI does not time out
  • Loading branch information
misko authored Feb 26, 2024
1 parent 394e9ba commit 595995a
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ jobs:
conda activate ocp-models
pip install -e .
pre-commit install
pytest --cov-report=xml --cov=ocpmodels/ /home/circleci/project/tests
pytest -vv --cov-report=xml --cov=ocpmodels/ /home/circleci/project/tests
- codecov/upload:
file: coverage.xml

Expand Down
9 changes: 9 additions & 0 deletions ocpmodels/common/relaxation/ase_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
cutoff: int = 6,
max_neighbors: int = 50,
cpu: bool = True,
seed: Optional[int] = None,
) -> None:
"""
OCP-ASE Calculator
Expand Down Expand Up @@ -173,6 +174,14 @@ def __init__(
checkpoint_path=checkpoint_path, checkpoint=checkpoint
)

seed = seed if seed is not None else self.trainer.config["cmd"]["seed"]
if seed is None:
logging.warning(
"No seed has been set in modelcheckpoint or OCPCalculator! Results may not be reproducible on re-run"
)
else:
self.trainer.set_seed(seed)

self.a2g = AtomsToGraphs(
max_neigh=max_neighbors,
radius=cutoff,
Expand Down
12 changes: 7 additions & 5 deletions ocpmodels/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,19 +198,21 @@ def load(self) -> None:
self.load_optimizer()
self.load_extras()

def load_seed_from_config(self) -> None:
def set_seed(self, seed) -> None:
# https://pytorch.org/docs/stable/notes/randomness.html
seed = self.config["cmd"]["seed"]
if seed is None:
return

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def load_seed_from_config(self) -> None:
# https://pytorch.org/docs/stable/notes/randomness.html
if self.config["cmd"]["seed"] is None:
return
self.set_seed(self.config["cmd"]["seed"])

def load_logger(self) -> None:
self.logger = None
if not self.is_debug and distutils.is_master():
Expand Down
2 changes: 1 addition & 1 deletion tests/common/__snapshots__/test_ase_calculator.ambr
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# name: TestCalculator.test_relaxation_final_energy
# name: TestCalculatorRelaxation.test_relaxation_final_energy
0.92
# ---
94 changes: 70 additions & 24 deletions tests/common/test_ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,25 @@ def load_data(request) -> None:
request.cls.atoms = atoms


@pytest.fixture(scope="class")
def load_model_list(request) -> None:
request.cls.model_list = [
def get_with_retry(url, retries=10):
retry = 0
while retry < retries:
try:
r = requests.get(url, timeout=10)
r.raise_for_status()
return r
except ConnectionError as e:
retry += 1
if retry == retries:
raise e
raise ConnectionError


# First let's just make sure all checkpoints are being loaded without any
# errors as part of the ASE calculator setup.
@pytest.mark.parametrize(
"model_url",
[
# SchNet
"https://dl.fbaipublicfiles.com/opencatalystproject/models/2020_11/s2ef/schnet_all_large.pt",
# DimeNet++
Expand All @@ -40,40 +56,70 @@ def load_model_list(request) -> None:
"https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/gemnet_oc_base_s2ef_all_md.pt",
# SCN
"https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/scn_all_md_s2ef.pt",
# eSCN
"https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m3_lay20_all_md_s2ef.pt",
# EquiformerV2
"https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_153M_ec4_allmd.pt",
]


@pytest.mark.usefixtures("load_data")
@pytest.mark.usefixtures("load_model_list")
class TestCalculator:
# First let's just make sure all checkpoints are being loaded without any
# errors as part of the ASE calculator setup.
def test_calculator_setup(self) -> None:
model_list = self.model_list
for url in model_list:
r = requests.get(url, stream=True)
# Equiformer v2 # already tested in test_relaxation_final_energy
# "https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_153M_ec4_allmd.pt",
# eSCNm # already tested in test_random_seed_final_energy
# "https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m3_lay20_all_md_s2ef.pt",
],
)
class TestCalculatorLoading:
def test_calculator_setup(self, model_url):
with get_with_retry(model_url) as r:
r.raise_for_status()

_ = OCPCalculator(checkpoint_path=io.BytesIO(r.content), cpu=True)


@pytest.mark.usefixtures("load_data")
class TestCalculatorRelaxation:
# Run an adslab relaxation using the ASE calculator and ase.optimize.BFGS
# with one model and compare the final energy.
def test_relaxation_final_energy(self, snapshot) -> None:
random.seed(1)
torch.manual_seed(1)

model_list = self.model_list
r = requests.get(model_list[-1], stream=True)
r.raise_for_status()
calc = OCPCalculator(checkpoint_path=io.BytesIO(r.content), cpu=True)
equiformerv2_url = "https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_06/oc20/s2ef/eq2_153M_ec4_allmd.pt"

with get_with_retry(equiformerv2_url) as r:
r.raise_for_status()
calc = OCPCalculator(
checkpoint_path=io.BytesIO(r.content), cpu=True
)

atoms = self.atoms
atoms.set_calculator(calc)
opt = BFGS(atoms)
opt.run(fmax=0.05, steps=100)

assert snapshot == round(atoms.get_potential_energy(), 2)


@pytest.mark.usefixtures("load_data")
class TestCalculatoreSCNSeeds:
def test_random_seed_final_energy(self):
# to big to run on CircleCI on github
# escn_url = "https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l6_m3_lay20_all_md_s2ef.pt"
escn_url = "https://dl.fbaipublicfiles.com/opencatalystproject/models/2023_03/s2ef/escn_l4_m2_lay12_2M_s2ef.pt"
seeds = [100, 200, 100]
results_by_seed = {}
# compute the value for each seed , make sure repeated seeds have the exact same output

with get_with_retry(escn_url) as r:
for seed in seeds:
calc = OCPCalculator(
checkpoint_path=io.BytesIO(r.content),
cpu=True,
seed=seed,
)

atoms = self.atoms
atoms.set_calculator(calc)

energy = atoms.get_potential_energy()
if seed in results_by_seed:
assert results_by_seed[seed] == energy
else:
results_by_seed[seed] = energy
# make sure different seeds give slightly different results , expected due to discretization error in grid
for seed_a in set(seeds):
for seed_b in set(seeds) - set([seed_a]):
assert results_by_seed[seed_a] != results_by_seed[seed_b]

0 comments on commit 595995a

Please sign in to comment.