diff --git a/pyproject.toml b/pyproject.toml index b3d7c61eaa..8311d67fc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ covalent = ["covalent>=0.234.1-rc.0; platform_system!='Windows'", "covalent-clou dask = ["dask[distributed]>=2023.12.1", "dask-jobqueue>=0.8.2"] defects = ["pymatgen-analysis-defects>=2024.10.22", "shakenbreak>=3.2.0"] jobflow = ["jobflow[fireworks]>=0.1.14", "jobflow-remote>=0.1.0"] -mlp = ["matgl>=1.1.2", "chgnet>=0.3.3", "mace-torch>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1"] +mlp = ["matgl>=1.1.2", "chgnet>=0.3.3", "mace-torch>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1", "orb-models>=4.1.0"] mp = ["atomate2>=0.0.14"] newtonnet = ["newtonnet>=1.1"] parsl = ["parsl[monitoring]>=2024.5.27; platform_system!='Windows'"] diff --git a/src/quacc/recipes/mlp/_base.py b/src/quacc/recipes/mlp/_base.py index 4c6922d0b5..2207560f55 100644 --- a/src/quacc/recipes/mlp/_base.py +++ b/src/quacc/recipes/mlp/_base.py @@ -3,6 +3,7 @@ from __future__ import annotations from functools import lru_cache +from importlib.util import find_spec from logging import getLogger from typing import TYPE_CHECKING @@ -16,21 +17,26 @@ @lru_cache def pick_calculator( - method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"], **kwargs + method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"], **kwargs ) -> Calculator: """ Adapted from `matcalc.util.get_universal_calculator`. + !!! Note + + To use `orb` method, `pynanoflann` must be installed. To install `pynanoflann`, + run `pip install "pynanoflann@git+https://github.com/dwastberg/pynanoflann"`. + Parameters ---------- method - Name of the calculator to use + Name of the calculator to use. **kwargs Custom kwargs for the underlying calculator. Set a value to `quacc.Remove` to remove a pre-existing key entirely. For a list of available keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`, - `matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator` - calculators. + `matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, or + `orb_models.forcefield.calculator.ORBCalculator` calculators. Returns ------- @@ -71,6 +77,21 @@ def pick_calculator( calc = SevenNetCalculator(**kwargs) + elif method.lower() == "orb": + if not find_spec("pynanoflann"): + raise ImportError( + """orb-models requires pynanoflann. + Install pynanoflann with `pip install "pynanoflann@git+https://github.com/dwastberg/pynanoflann"`. + """ + ) + from orb_models import __version__ + from orb_models.forcefield import pretrained + from orb_models.forcefield.calculator import ORBCalculator + + orb_model = kwargs.get("model", "orb_v2") + orbff = getattr(pretrained, orb_model)() + calc = ORBCalculator(model=orbff, **kwargs) + else: raise ValueError(f"Unrecognized {method=}.") diff --git a/src/quacc/recipes/mlp/core.py b/src/quacc/recipes/mlp/core.py index 50e5338322..d6c47d47d1 100644 --- a/src/quacc/recipes/mlp/core.py +++ b/src/quacc/recipes/mlp/core.py @@ -21,7 +21,7 @@ @job def static_job( atoms: Atoms, - method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"], + method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"], properties: list[str] | None = None, additional_fields: dict[str, Any] | None = None, **calc_kwargs, @@ -43,7 +43,8 @@ def static_job( Custom kwargs for the underlying calculator. Set a value to `quacc.Remove` to remove a pre-existing key entirely. For a list of available keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`, - `matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator` calculators. + `matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, or + `orb_models.forcefield.calculator.ORBCalculator` calculators. Returns ------- @@ -63,7 +64,7 @@ def static_job( @job def relax_job( atoms: Atoms, - method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"], + method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"], relax_cell: bool = False, opt_params: OptParams | None = None, additional_fields: dict[str, Any] | None = None, @@ -89,8 +90,8 @@ def relax_job( Custom kwargs for the underlying calculator. Set a value to `quacc.Remove` to remove a pre-existing key entirely. For a list of available keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`, - `matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator` - calculators. + `matgl.ext.ase.M3GNetCalculator`, `sevenn.sevennet_calculator.SevenNetCalculator`, or + `orb_models.forcefield.calculator.ORBCalculator` calculators. Returns ------- diff --git a/src/quacc/recipes/mlp/phonons.py b/src/quacc/recipes/mlp/phonons.py index 09c9feced0..8223432bd9 100644 --- a/src/quacc/recipes/mlp/phonons.py +++ b/src/quacc/recipes/mlp/phonons.py @@ -33,7 +33,7 @@ ) def phonon_flow( atoms: Atoms, - method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"], + method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet", "orb"], symprec: float = 1e-4, min_lengths: float | tuple[float, float, float] | None = 20.0, supercell_matrix: ( diff --git a/tests/core/recipes/mlp_recipes/test_core_recipes.py b/tests/core/recipes/mlp_recipes/test_core_recipes.py index de42214e4d..b25abe99e1 100644 --- a/tests/core/recipes/mlp_recipes/test_core_recipes.py +++ b/tests/core/recipes/mlp_recipes/test_core_recipes.py @@ -24,6 +24,9 @@ if has_sevennet := find_spec("sevenn"): methods.append("sevennet") +if has_orb := find_spec("orb_models"): + methods.append("orb") + @pytest.mark.skipif(has_chgnet is None, reason="chgnet not installed") def test_bad_method(): @@ -52,14 +55,30 @@ def test_static_job(tmp_path, monkeypatch, method): "m3gnet": -4.0938973, "mace-mp-0": -4.083906650543213, "sevennet": -4.096191883087158, + "orb": -4.093477725982666, } atoms = bulk("Cu") output = static_job(atoms, method=method) - assert output["results"]["energy"] == pytest.approx(ref_energy[method]) + assert output["results"]["energy"] == pytest.approx(ref_energy[method], rel=1e-4) assert np.shape(output["results"]["forces"]) == (1, 3) assert output["atoms"] == atoms +def test_relax_job_missing_pynanoflann(monkeypatch): + def mock_find_spec(name): + if name == "pynanoflann": + return None + return find_spec(name) + + import quacc.recipes.mlp._base + + quacc.recipes.mlp._base.pick_calculator.cache_clear() + monkeypatch.setattr("importlib.util.find_spec", mock_find_spec) + monkeypatch.setattr("quacc.recipes.mlp._base.find_spec", mock_find_spec) + with pytest.raises(ImportError, match=r"orb-models requires pynanoflann"): + relax_job(bulk("Cu"), method="orb") + + @pytest.mark.parametrize("method", methods) def test_relax_job(tmp_path, monkeypatch, method): monkeypatch.chdir(tmp_path) @@ -73,12 +92,13 @@ def test_relax_job(tmp_path, monkeypatch, method): "m3gnet": -32.75003433227539, "mace-mp-0": -32.6711566550002, "sevennet": -32.76924133300781, + "orb": -32.7361946105957, } atoms = bulk("Cu") * (2, 2, 2) atoms[0].position += 0.1 output = relax_job(atoms, method=method) - assert output["results"]["energy"] == pytest.approx(ref_energy[method]) + assert output["results"]["energy"] == pytest.approx(ref_energy[method], rel=1e-4) assert np.shape(output["results"]["forces"]) == (8, 3) assert output["atoms"] != atoms assert output["atoms"].get_volume() == pytest.approx(atoms.get_volume()) @@ -114,12 +134,13 @@ def test_relax_cell_job(tmp_path, monkeypatch, method): "m3gnet": -32.750858306884766, "mace-mp-0": -32.67840391814377, "sevennet": -32.76963806152344, + "orb": -32.73428726196289, } atoms = bulk("Cu") * (2, 2, 2) atoms[0].position += 0.1 output = relax_job(atoms, method=method, relax_cell=True) - assert output["results"]["energy"] == pytest.approx(ref_energy[method]) + assert output["results"]["energy"] == pytest.approx(ref_energy[method], rel=1e-4) assert np.shape(output["results"]["forces"]) == (8, 3) assert output["atoms"] != atoms assert output["atoms"].get_volume() != pytest.approx(atoms.get_volume()) diff --git a/tests/requirements-mlp.txt b/tests/requirements-mlp.txt index 7cbdfe2848..583a8262e2 100644 --- a/tests/requirements-mlp.txt +++ b/tests/requirements-mlp.txt @@ -2,3 +2,5 @@ chgnet==0.4.0 mace-torch==0.3.9 torch-dftd==0.5.1 sevenn==0.10.3 +orb-models==4.1.0 +pynanoflann@git+https://github.com/dwastberg/pynanoflann