Skip to content

Commit

Permalink
feat: add orb-models support to mlp recipes (#2574)
Browse files Browse the repository at this point in the history
## Summary of Changes
Added **orb-models** to the MLP recipes based on issue #2429. 

Key changes include:

1. Added orb-models implementation to mlp recipes
2. Created unit tests to validate orb-models functionality
3. Ensured compatibility with existing MLP workflows

## Requirements

-[X] My PR is focused on a single feature addition [#2429 ].
-[X] My PR has relevant, comprehensive unit tests.
-[X] My PR is on a custom branch (feature/orbnet-models).

## Note to the reviewer: 
1. orb-models outputs are somewhat more stochastic than other tests, so
I used a `rel` flag to limit the accuracy of the pytest comparison. I
set it to `1e-4` after some failed checks.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mamunm and pre-commit-ci[bot] authored Dec 23, 2024
1 parent d771f5f commit 506c157
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ covalent = ["covalent>=0.234.1-rc.0; platform_system!='Windows'", "covalent-clou
dask = ["dask[distributed]>=2023.12.1", "dask-jobqueue>=0.8.2"]
defects = ["pymatgen-analysis-defects>=2024.10.22", "shakenbreak>=3.2.0"]
jobflow = ["jobflow[fireworks]>=0.1.14", "jobflow-remote>=0.1.0"]
mlp = ["matgl>=1.1.2", "chgnet>=0.3.3", "mace-torch>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1"]
mlp = ["matgl>=1.1.2", "chgnet>=0.3.3", "mace-torch>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1", "orb-models>=4.1.0"]
mp = ["atomate2>=0.0.14"]
newtonnet = ["newtonnet>=1.1"]
parsl = ["parsl[monitoring]>=2024.5.27; platform_system!='Windows'"]
Expand Down
29 changes: 25 additions & 4 deletions src/quacc/recipes/mlp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from functools import lru_cache
from importlib.util import find_spec
from logging import getLogger
from typing import TYPE_CHECKING

Expand All @@ -16,21 +17,26 @@

@lru_cache
def pick_calculator(
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"], **kwargs
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"], **kwargs
) -> Calculator:
"""
Adapted from `matcalc.util.get_universal_calculator`.
!!! Note
To use `orb` method, `pynanoflann` must be installed. To install `pynanoflann`,
run `pip install "pynanoflann@git+https://github.com/dwastberg/pynanoflann"`.
Parameters
----------
method
Name of the calculator to use
Name of the calculator to use.
**kwargs
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
`matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator`
calculators.
`matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, or
`orb_models.forcefield.calculator.ORBCalculator` calculators.
Returns
-------
Expand Down Expand Up @@ -71,6 +77,21 @@ def pick_calculator(

calc = SevenNetCalculator(**kwargs)

elif method.lower() == "orb":
if not find_spec("pynanoflann"):
raise ImportError(
"""orb-models requires pynanoflann.
Install pynanoflann with `pip install "pynanoflann@git+https://github.com/dwastberg/pynanoflann"`.
"""
)
from orb_models import __version__
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator

orb_model = kwargs.get("model", "orb_v2")
orbff = getattr(pretrained, orb_model)()
calc = ORBCalculator(model=orbff, **kwargs)

else:
raise ValueError(f"Unrecognized {method=}.")

Expand Down
11 changes: 6 additions & 5 deletions src/quacc/recipes/mlp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@job
def static_job(
atoms: Atoms,
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"],
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"],
properties: list[str] | None = None,
additional_fields: dict[str, Any] | None = None,
**calc_kwargs,
Expand All @@ -43,7 +43,8 @@ def static_job(
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
`matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator` calculators.
`matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, or
`orb_models.forcefield.calculator.ORBCalculator` calculators.
Returns
-------
Expand All @@ -63,7 +64,7 @@ def static_job(
@job
def relax_job(
atoms: Atoms,
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"],
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"],
relax_cell: bool = False,
opt_params: OptParams | None = None,
additional_fields: dict[str, Any] | None = None,
Expand All @@ -89,8 +90,8 @@ def relax_job(
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
`matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator`
calculators.
`matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, or
`orb_models.forcefield.calculator.ORBCalculator` calculators.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/quacc/recipes/mlp/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
def phonon_flow(
atoms: Atoms,
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"],
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"],
symprec: float = 1e-4,
min_lengths: float | tuple[float, float, float] | None = 20.0,
supercell_matrix: (
Expand Down
27 changes: 24 additions & 3 deletions tests/core/recipes/mlp_recipes/test_core_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
if has_sevennet := find_spec("sevenn"):
methods.append("sevennet")

if has_orb := find_spec("orb_models"):
methods.append("orb")


@pytest.mark.skipif(has_chgnet is None, reason="chgnet not installed")
def test_bad_method():
Expand Down Expand Up @@ -52,14 +55,30 @@ def test_static_job(tmp_path, monkeypatch, method):
"m3gnet": -4.0938973,
"mace-mp-0": -4.083906650543213,
"sevennet": -4.096191883087158,
"orb": -4.093477725982666,
}
atoms = bulk("Cu")
output = static_job(atoms, method=method)
assert output["results"]["energy"] == pytest.approx(ref_energy[method])
assert output["results"]["energy"] == pytest.approx(ref_energy[method], rel=1e-4)
assert np.shape(output["results"]["forces"]) == (1, 3)
assert output["atoms"] == atoms


def test_relax_job_missing_pynanoflann(monkeypatch):
def mock_find_spec(name):
if name == "pynanoflann":
return None
return find_spec(name)

import quacc.recipes.mlp._base

quacc.recipes.mlp._base.pick_calculator.cache_clear()
monkeypatch.setattr("importlib.util.find_spec", mock_find_spec)
monkeypatch.setattr("quacc.recipes.mlp._base.find_spec", mock_find_spec)
with pytest.raises(ImportError, match=r"orb-models requires pynanoflann"):
relax_job(bulk("Cu"), method="orb")


@pytest.mark.parametrize("method", methods)
def test_relax_job(tmp_path, monkeypatch, method):
monkeypatch.chdir(tmp_path)
Expand All @@ -73,12 +92,13 @@ def test_relax_job(tmp_path, monkeypatch, method):
"m3gnet": -32.75003433227539,
"mace-mp-0": -32.6711566550002,
"sevennet": -32.76924133300781,
"orb": -32.7361946105957,
}

atoms = bulk("Cu") * (2, 2, 2)
atoms[0].position += 0.1
output = relax_job(atoms, method=method)
assert output["results"]["energy"] == pytest.approx(ref_energy[method])
assert output["results"]["energy"] == pytest.approx(ref_energy[method], rel=1e-4)
assert np.shape(output["results"]["forces"]) == (8, 3)
assert output["atoms"] != atoms
assert output["atoms"].get_volume() == pytest.approx(atoms.get_volume())
Expand Down Expand Up @@ -114,12 +134,13 @@ def test_relax_cell_job(tmp_path, monkeypatch, method):
"m3gnet": -32.750858306884766,
"mace-mp-0": -32.67840391814377,
"sevennet": -32.76963806152344,
"orb": -32.73428726196289,
}

atoms = bulk("Cu") * (2, 2, 2)
atoms[0].position += 0.1
output = relax_job(atoms, method=method, relax_cell=True)
assert output["results"]["energy"] == pytest.approx(ref_energy[method])
assert output["results"]["energy"] == pytest.approx(ref_energy[method], rel=1e-4)
assert np.shape(output["results"]["forces"]) == (8, 3)
assert output["atoms"] != atoms
assert output["atoms"].get_volume() != pytest.approx(atoms.get_volume())
2 changes: 2 additions & 0 deletions tests/requirements-mlp.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ chgnet==0.4.0
mace-torch==0.3.9
torch-dftd==0.5.1
sevenn==0.10.3
orb-models==4.1.0
pynanoflann@git+https://github.com/dwastberg/pynanoflann

0 comments on commit 506c157

Please sign in to comment.