Skip to content

Commit

Permalink
feat: add sevennet support to mlp recipes (#2553)
Browse files Browse the repository at this point in the history
## Summary of Changes
Added **sevennet** to the MLP recipes based on issue #2429. 

Key changes include:

1. Added sevennet implementation to mlp recipes
2. Created unit tests to validate sevennet functionality
3. Ensured compatibility with existing MLP workflows

## Requirements

-[X] My PR is focused on a single feature addition [#2429 ].
-[X] My PR has relevant, comprehensive unit tests.
-[X] My PR is on a custom branch (feature/sevennetMLP).

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mamunm and pre-commit-ci[bot] authored Dec 1, 2024
1 parent 4696bd8 commit e61f75d
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 8 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ covalent = ["covalent>=0.234.1-rc.0; platform_system!='Windows'", "covalent-clou
dask = ["dask[distributed]>=2023.12.1", "dask-jobqueue>=0.8.2"]
defects = ["pymatgen-analysis-defects>=2023.8.22", "shakenbreak>=3.2.0"]
jobflow = ["jobflow[fireworks]>=0.1.14", "jobflow-remote>=0.1.0"]
mlp = ["matgl>=1.1.2", "chgnet>=0.3.3", "mace-torch>=0.3.3", "torch-dftd>=0.4.0"]
mlp = ["matgl>=1.1.2", "chgnet>=0.3.3", "mace-torch>=0.3.3", "torch-dftd>=0.4.0", "sevenn>=0.10.1"]
mp = ["atomate2>=0.0.14"]
newtonnet = ["newtonnet>=1.1"]
parsl = ["parsl[monitoring]>=2024.5.27; platform_system!='Windows'"]
Expand Down
11 changes: 9 additions & 2 deletions src/quacc/recipes/mlp/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@lru_cache
def pick_calculator(
method: Literal["mace-mp-0", "m3gnet", "chgnet"], **kwargs
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"], **kwargs
) -> Calculator:
"""
Adapted from `matcalc.util.get_universal_calculator`.
Expand All @@ -29,7 +29,8 @@ def pick_calculator(
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
or `matgl.ext.ase.M3GNetCalculator` calculators.
`matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator`
calculators.
Returns
-------
Expand Down Expand Up @@ -64,6 +65,12 @@ def pick_calculator(
kwargs["default_dtype"] = "float64"
calc = mace_mp(**kwargs)

elif method.lower() == "sevennet":
from sevenn import __version__
from sevenn.sevennet_calculator import SevenNetCalculator

calc = SevenNetCalculator(**kwargs)

else:
raise ValueError(f"Unrecognized {method=}.")

Expand Down
9 changes: 5 additions & 4 deletions src/quacc/recipes/mlp/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
@job
def static_job(
atoms: Atoms,
method: Literal["mace-mp-0", "m3gnet", "chgnet"],
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"],
properties: list[str] | None = None,
additional_fields: dict[str, Any] | None = None,
**calc_kwargs,
Expand All @@ -43,7 +43,7 @@ def static_job(
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
or `matgl.ext.ase.M3GNetCalculator` calculators.
`matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator` calculators.
Returns
-------
Expand All @@ -63,7 +63,7 @@ def static_job(
@job
def relax_job(
atoms: Atoms,
method: Literal["mace-mp-0", "m3gnet", "chgnet"],
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"],
relax_cell: bool = False,
opt_params: OptParams | None = None,
additional_fields: dict[str, Any] | None = None,
Expand All @@ -89,7 +89,8 @@ def relax_job(
Custom kwargs for the underlying calculator. Set a value to
`quacc.Remove` to remove a pre-existing key entirely. For a list of available
keys, refer to the `mace.calculators.mace_mp`, `chgnet.model.dynamics.CHGNetCalculator`,
or `matgl.ext.ase.M3GNetCalculator` calculators.
`matgl.ext.ase.M3GNetCalculator`, or `sevenn.sevennet_calculator.SevenNetCalculator`
calculators.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/quacc/recipes/mlp/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
def phonon_flow(
atoms: Atoms,
method: Literal["mace-mp-0", "m3gnet", "chgnet"],
method: Literal["mace-mp-0", "m3gnet", "chgnet", "sevennet"],
symprec: float = 1e-4,
min_lengths: float | tuple[float, float, float] | None = 20.0,
supercell_matrix: (
Expand Down
6 changes: 6 additions & 0 deletions tests/core/recipes/mlp_recipes/test_core_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
if has_chgnet := find_spec("chgnet"):
methods.append("chgnet")

if has_sevennet := find_spec("sevenn"):
methods.append("sevennet")


@pytest.mark.skipif(has_chgnet is None, reason="chgnet not installed")
def test_bad_method():
Expand Down Expand Up @@ -48,6 +51,7 @@ def test_static_job(tmp_path, monkeypatch, method):
"chgnet": -4.083308219909668,
"m3gnet": -4.0938973,
"mace-mp-0": -4.083906650543213,
"sevennet": -4.096191883087158,
}
atoms = bulk("Cu")
output = static_job(atoms, method=method)
Expand All @@ -68,6 +72,7 @@ def test_relax_job(tmp_path, monkeypatch, method):
"chgnet": -32.665428161621094,
"m3gnet": -32.75003433227539,
"mace-mp-0": -32.6711566550002,
"sevennet": -32.76924133300781,
}

atoms = bulk("Cu") * (2, 2, 2)
Expand Down Expand Up @@ -108,6 +113,7 @@ def test_relax_cell_job(tmp_path, monkeypatch, method):
"chgnet": -32.66698455810547,
"m3gnet": -32.750858306884766,
"mace-mp-0": -32.67840391814377,
"sevennet": -32.76963806152344,
}

atoms = bulk("Cu") * (2, 2, 2)
Expand Down
1 change: 1 addition & 0 deletions tests/requirements-mlp.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
chgnet==0.4.0
mace-torch==0.3.8
torch-dftd==0.5.1
sevenn==0.10.1

0 comments on commit e61f75d

Please sign in to comment.