OptimizableBatch and stress relaxations (#718)
* remove r_edges, radius, max_neigh and add deprecation warning

* edit typing and dont use dicts as default

* use super() and remove overkill deprecation warning

* set implemented_properties from config

* make determine step a method

* allow calculator to operate on batches

* only update if old config is used

* reshape properties

* no test classes in ase calculator

* yaml load fix

* use mappingproxy

* expressive import

* remove duplicated code

* optimizable batch class for ase compatible batch relaxations

* fix optimizable batch

* optimizable goodies

* apply force constraints

* use optimizable batch instead and remove torchcalc

* update ml relaxations to use optimizable batch correctly

* force_consistent check for ASE compat

* force_consistent check for ASE compat

* check force_consistent

* init docs in lbfgs

* unitcellfilter for batch relaxations

* ruff

* UnitCellOptimizable as child class instead of filter

* allow running unit cell relaxations

* ruff

* no grad in run_relaxations

* make batched_dot and determine_step methods

* imports

* rename to optimizableunitcellbatch

* allow passing energy and forces explicitly to batch to atoms

* check convergence in optimizable and allow passing general results to atoms_from_batch

* relaxation test

* unit tests

* move update mask to optimizable

* use energy instead of y

* all setting/getting positions and convergence in optimizable

* more (unfinished) tests

* backwards compatible test

* minor fixes

* code cleanup

* add/fix tests

* fix lbfgs

* assert using norm

* add eps to masked batches if using ASE optimizers

* match iterations from previous implementation

* use float64 for forces

* float32

* use energy_relaxed instead of y_relaxed

* energy_relaxed and more explicit error msg

* default to batch_size 1 if not set in config

* keep float64 training

* rename y_relaxed -> energy_relaxed

* rm expcell batch

* convenience commit from no_experimental_resolve

* use numatoms tensor for cell factor

* remove positions tests (wrapping atoms gives different results)

* allow wrapping positions in batch to atoms

* fix test

* wrap_positions in batch_to_atoms

* take a2g properties from model

* test lbfgs traj writes

* remove comments

* use model generate graph

* fix cell_factor

* fix using model in ddp

* fix r_edges in OCPcalculator

* write initial and final structure if save_full is false

* check unique atoms saved in trajectory

* tighter tol

* update ASE release comment

* remove cumulative mask option

* remove left over cumulative_mask

* fix batching when sids as str

* do not try to fetch energy and forces if no explicit results

* accept Path objects

* clean up setting defaults

* expose ml_relax in relaxation

* force set r_pbc True

* make relax_opt optional

* no ema on inference only

* define ema none to avoid issues

* lower force threshold to make sure test does not converge

* clean up exception msg

* allow strings in batch

* remove device argument from lbfgs

* minor cleanup

* fix optimizable import

* do not pass device in ml_relax

* simplify enforce max neighbors

* fix tests (still not testing stress)

* pin sphinx autoapi

* typo in version

---------

Co-authored-by: zulissimeta <[email protected]>
Co-authored-by: Zack Ulissi <[email protected]>
3 people authored Dec 3, 2024
1 parent e11e78e commit 3f695ef
Showing 17 changed files with 1,068 additions and 241 deletions.
2 changes: 1 addition & 1 deletion packages/fairchem-core/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [

[project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev]
dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi", "umap-learn", "vdict"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "umap-learn", "vdict"]
adsorbml = ["dscribe","x3dase","scikit-image"]

[project.scripts]
13 changes: 13 additions & 0 deletions src/fairchem/core/common/relaxation/__init__.py
@@ -0,0 +1,13 @@
"""
Copyright (c) Meta, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

from .ml_relaxation import ml_relax
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch

__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"]
108 changes: 77 additions & 31 deletions src/fairchem/core/common/relaxation/ase_utils.py
@@ -14,13 +14,15 @@

import copy
import logging
from typing import ClassVar
from types import MappingProxyType
from typing import TYPE_CHECKING

import torch
from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.singlepoint import SinglePointCalculator as sp
from ase.calculators.singlepoint import SinglePointCalculator
from ase.constraints import FixAtoms
from ase.geometry import wrap_positions

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import (
@@ -33,51 +35,93 @@
from fairchem.core.models.model_registry import model_name_to_local_file
from fairchem.core.preprocessing import AtomsToGraphs

if TYPE_CHECKING:
from pathlib import Path

def batch_to_atoms(batch):
from torch_geometric.data import Batch


# system level model predictions have different shapes than expected by ASE
ASE_PROP_RESHAPE = MappingProxyType(
{"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)}
)


def batch_to_atoms(
batch: Batch,
results: dict[str, torch.Tensor] | None = None,
wrap_pos: bool = True,
eps: float = 1e-7,
) -> list[Atoms]:
"""Convert a data batch to ase Atoms
Args:
batch: data batch
results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results
are given no calculator will be added to the atoms objects.
wrap_pos: wrap positions back into the cell.
eps: Small number to prevent slightly negative coordinates from being wrapped.
Returns:
list of Atoms
"""
n_systems = batch.natoms.shape[0]
natoms = batch.natoms.tolist()
numbers = torch.split(batch.atomic_numbers, natoms)
fixed = torch.split(batch.fixed.to(torch.bool), natoms)
forces = torch.split(batch.force, natoms)
if results is not None:
results = {
key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist()
if len(val) == len(batch)
else [v.cpu().detach().numpy() for v in torch.split(val, natoms)]
for key, val in results.items()
}

positions = torch.split(batch.pos, natoms)
tags = torch.split(batch.tags, natoms)
cells = batch.cell
energies = batch.energy.view(-1).tolist()

atoms_objects = []
for idx in range(n_systems):
pos = positions[idx].cpu().detach().numpy()
cell = cells[idx].cpu().detach().numpy()

# TODO take pbc from data
if wrap_pos:
pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps)

atoms = Atoms(
numbers=numbers[idx].tolist(),
positions=positions[idx].cpu().detach().numpy(),
cell=cell,
positions=pos,
tags=tags[idx].tolist(),
cell=cells[idx].cpu().detach().numpy(),
constraint=FixAtoms(mask=fixed[idx].tolist()),
pbc=[True, True, True],
)
calc = sp(
atoms=atoms,
energy=energies[idx],
forces=forces[idx].cpu().detach().numpy(),
)
atoms.set_calculator(calc)

if results is not None:
calc = SinglePointCalculator(
atoms=atoms, **{key: val[idx] for key, val in results.items()}
)
atoms.set_calculator(calc)

atoms_objects.append(atoms)

return atoms_objects
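
The `eps` argument documented in `batch_to_atoms` guards against a wrapping artifact: a coordinate that is negative only by floating-point noise would otherwise jump to the far side of the cell. The sketch below illustrates the idea on a single fractional coordinate in plain Python. It is not ASE's implementation; `wrap_fractional` is a hypothetical helper, and the shift-by-`eps` formula is an assumption about how such a tolerance can be applied.

```python
def wrap_fractional(frac: float, eps: float = 1e-7) -> float:
    """Wrap a fractional coordinate into roughly [-eps, 1 - eps).

    Hypothetical helper for illustration only (not part of ASE or fairchem).
    """
    # Shift up by eps before the modulo so values in [-eps, 0) are left alone,
    # then shift back so every other value keeps its usual wrapped position.
    return (frac + eps) % 1.0 - eps


# A plain modulo sends a tiny negative value to the opposite cell face.
naive = -1e-9 % 1.0

# With the tolerance, the same value stays where it was.
tolerant = wrap_fractional(-1e-9)

# Coordinates that are clearly outside the cell are still wrapped as usual.
inside = wrap_fractional(-0.25)
```

Without a tolerance like this, two relaxations that differ only by rounding noise could report atoms on opposite faces of the cell, which may be why one commit message above notes that wrapping atoms gives different results in position tests.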


class OCPCalculator(Calculator):
implemented_properties: ClassVar[list[str]] = ["energy", "forces"]
"""ASE based calculator using an OCP model"""

_reshaped_props = ASE_PROP_RESHAPE

def __init__(
self,
config_yml: str | None = None,
checkpoint_path: str | None = None,
checkpoint_path: str | Path | None = None,
model_name: str | None = None,
local_cache: str | None = None,
trainer: str | None = None,
cutoff: int = 6,
max_neighbors: int = 50,
cpu: bool = True,
seed: int | None = None,
) -> None:
@@ -96,16 +140,12 @@ def __init__(
Directory to save pretrained model checkpoints.
trainer (str):
OCP trainer to be used. "forces" for S2EF, "energy" for IS2RE.
cutoff (int):
Cutoff radius to be used for data preprocessing.
max_neighbors (int):
Maximum amount of neighbors to store for a given atom.
cpu (bool):
Whether to load and run the model on CPU. Set `False` for GPU.
"""
setup_imports()
setup_logging()
Calculator.__init__(self)
super().__init__()

if model_name is not None:
if checkpoint_path is not None:
@@ -165,9 +205,8 @@ def __init__(
### backwards compatibility with OCP v<2.0
config = update_config(config)

# Save config so obj can be transported over network (pkl)
self.config = copy.deepcopy(config)
self.config["checkpoint"] = checkpoint_path
self.config["checkpoint"] = str(checkpoint_path)
del config["dataset"]["src"]

self.trainer = registry.get_trainer_class(config["trainer"])(
Expand Down Expand Up @@ -199,14 +238,13 @@ def __init__(
self.trainer.set_seed(seed)

self.a2g = AtomsToGraphs(
max_neigh=max_neighbors,
radius=cutoff,
r_energy=False,
r_forces=False,
r_distances=False,
r_edges=False,
r_pbc=True,
r_edges=not self.trainer.model.otf_graph, # otf graph should not be a property of the model...
)
self.implemented_properties = list(self.config["outputs"].keys())

def load_checkpoint(
self, checkpoint_path: str, checkpoint: dict | None = None
@@ -217,6 +255,8 @@ def load_checkpoint(
Args:
checkpoint_path: string
Path to trained model
checkpoint: dict
A pretrained checkpoint dict
"""
try:
self.trainer.load_checkpoint(
@@ -225,14 +265,20 @@ def load_checkpoint(
except NotImplementedError:
logging.warning("Unable to load checkpoint!")

def calculate(self, atoms: Atoms, properties, system_changes) -> None:
Calculator.calculate(self, atoms, properties, system_changes)
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
def calculate(self, atoms: Atoms | Batch, properties, system_changes) -> None:
"""Calculate implemented properties for a single Atoms object or a Batch of them."""
super().calculate(atoms, properties, system_changes)
if isinstance(atoms, Atoms):
data_object = self.a2g.convert(atoms)
batch = data_list_collater([data_object], otf_graph=True)
else:
batch = atoms

predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True)

for key in predictions:
_pred = predictions[key]
_pred = _pred.item() if _pred.numel() == 1 else _pred.cpu().numpy()
if key in OCPCalculator._reshaped_props:
_pred = _pred.reshape(OCPCalculator._reshaped_props.get(key)).squeeze()
self.results[key] = _pred
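
The `ASE_PROP_RESHAPE` lookup in this file does two things: it records which predicted properties need a tensor shape that ASE expects, and, by being a `MappingProxyType`, it keeps that shared table read-only. A minimal plain-Python sketch of both ideas follows. The names `PROP_SHAPES` and `reshape_prediction` are hypothetical, and the sketch handles a single system with nested lists rather than the batched `(-1, 3, 3)` tensor view used in the diff.

```python
from types import MappingProxyType

# Hypothetical stand-in for ASE_PROP_RESHAPE: property name -> (rows, cols).
PROP_SHAPES = MappingProxyType({"stress": (3, 3)})


def reshape_prediction(key: str, flat: list[float]) -> list:
    """Return `flat` as nested rows if `key` has a registered shape, else unchanged."""
    shape = PROP_SHAPES.get(key)
    if shape is None:
        return flat
    rows, cols = shape
    if len(flat) != rows * cols:
        raise ValueError(f"{key!r} needs {rows * cols} values, got {len(flat)}")
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


# A flat 9-component stress prediction becomes a 3x3 matrix.
stress = reshape_prediction("stress", [1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0])

# Properties without a registered shape pass through untouched.
energy = reshape_prediction("energy", [-4.2])

# The proxy rejects writes, so the shared table cannot be mutated by accident.
try:
    PROP_SHAPES["stress"] = (9,)
    is_read_only = False
except TypeError:
    is_read_only = True
```

A read-only mapping is a reasonable choice for a module-level or class-level default, since a plain dict shared across instances can be modified in place by any caller.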
117 changes: 79 additions & 38 deletions src/fairchem/core/common/relaxation/ml_relaxation.py
@@ -10,85 +10,126 @@
import logging
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from torch_geometric.data import Batch

from fairchem.core.common.typing import assert_is_instance
from fairchem.core.datasets.lmdb_dataset import data_list_collater

from .optimizers.lbfgs_torch import LBFGS, TorchCalc
from .optimizable import OptimizableBatch, OptimizableUnitCellBatch
from .optimizers.lbfgs_torch import LBFGS

if TYPE_CHECKING:
from fairchem.core.trainers import BaseTrainer


def ml_relax(
batch,
model,
batch: Batch,
model: BaseTrainer,
steps: int,
fmax: float,
relax_opt,
save_full_traj,
device: str = "cuda:0",
transform=None,
early_stop_batch: bool = False,
relax_opt: dict[str] | None = None,
relax_cell: bool = False,
relax_volume: bool = False,
save_full_traj: bool = True,
transform: torch.nn.Module | None = None,
mask_converged: bool = True,
):
"""
Runs ML-based relaxations.
"""Runs ML-based relaxations.
Args:
batch: object
model: object
steps: int
Max number of steps in the structure relaxation.
fmax: float
Structure relaxation terminates when the max force
of the system is no bigger than fmax.
relax_opt: str
Optimizer and corresponding parameters to be used for structure relaxations.
save_full_traj: bool
Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
batch: a data batch object.
model: a trainer object with model.
steps: Max number of steps in the structure relaxation.
fmax: Structure relaxation terminates when the max force of the system is no bigger than fmax.
relax_opt: Optimizer parameters to be used for structure relaxations.
relax_cell: if true will use stress predictions to relax crystallographic cell.
The model given must predict stress
relax_volume: if true will relax the cell isotropically. the given model must predict stress.
save_full_traj: Whether to save out the full ASE trajectory. If False, only save out initial and final frames.
mask_converged: whether to mask batches where all atoms are below convergence threshold
cumulative_mask: if true, once system is masked then it remains masked even if new predictions give forces
above threshold, ie. once masked always masked. Note if this is used make sure to check convergence with
the same fmax always
"""
relax_opt = relax_opt or {}
# if not pbc is set, ignore it when comparing batches
if not hasattr(batch, "pbc"):
OptimizableBatch.ignored_changes = {"pbc"}

batches = deque([batch])
relaxed_batches = []
while batches:
batch = batches.popleft()
oom = False
ids = batch.sid
calc = TorchCalc(model, transform)

# clone the batch otherwise you can not run batch.to_data_list
# see https://github.com/pyg-team/pytorch_geometric/issues/8439#issuecomment-1826747915
if relax_cell or relax_volume:
optimizable = OptimizableUnitCellBatch(
batch.clone(),
trainer=model,
transform=transform,
mask_converged=mask_converged,
hydrostatic_strain=relax_volume,
)
else:
optimizable = OptimizableBatch(
batch.clone(),
trainer=model,
transform=transform,
mask_converged=mask_converged,
)

# Run ML-based relaxation
traj_dir = relax_opt.get("traj_dir", None)
traj_dir = relax_opt.get("traj_dir")
relax_opt.update({"traj_dir": Path(traj_dir) if traj_dir is not None else None})

optimizer = LBFGS(
batch,
calc,
maxstep=relax_opt.get("maxstep", 0.2),
memory=relax_opt["memory"],
damping=relax_opt.get("damping", 1.2),
alpha=relax_opt.get("alpha", 80.0),
device=device,
optimizable_batch=optimizable,
save_full_traj=save_full_traj,
traj_dir=Path(traj_dir) if traj_dir is not None else None,
traj_names=ids,
early_stop_batch=early_stop_batch,
**relax_opt,
)

e: RuntimeError | None = None
try:
relaxed_batch = optimizer.run(fmax=fmax, steps=steps)
relaxed_batches.append(relaxed_batch)
optimizer.run(fmax=fmax, steps=steps)
relaxed_batches.append(optimizable.batch)
except RuntimeError as err:
e = err
oom = True
torch.cuda.empty_cache()

if oom:
# move OOM recovery code outside of except clause to allow tensors to be freed.
# move OOM recovery code outside of the except clause to allow tensors to be freed.
data_list = batch.to_data_list()
if len(data_list) == 1:
raise assert_is_instance(e, RuntimeError)
logging.info(
f"Failed to relax batch with size: {len(data_list)}, splitting into two..."
)
mid = len(data_list) // 2
batches.appendleft(data_list_collater(data_list[:mid]))
batches.appendleft(data_list_collater(data_list[mid:]))
batches.appendleft(
data_list_collater(data_list[:mid], otf_graph=optimizable.otf_graph)
)
batches.appendleft(
data_list_collater(data_list[mid:], otf_graph=optimizable.otf_graph)
)

# reset for good measure
OptimizableBatch.ignored_changes = {}

relaxed_batch = Batch.from_data_list(relaxed_batches)

# Batch.from_data_list is not intended to be used with a list of batches, so when sid is a list of str
# it will be incorrectly collated as a list of lists for each batch.
# but we can not use to_data_list in the relaxed batches (since they have been changed, see linked comment above).
# So instead just manually fix it for now. Remove this once pyg dependency is removed
if isinstance(relaxed_batch.sid, list):
relaxed_batch.sid = [sid for sid_list in relaxed_batch.sid for sid in sid_list]

return Batch.from_data_list(relaxed_batches)
return relaxed_batch
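
The out-of-memory handling in `ml_relax` follows a split-and-retry pattern: a failing batch is halved, both halves go back to the front of the queue, and a batch of size one that still fails re-raises the error. The sketch below reproduces only that control flow with plain lists. `relax_with_splitting` and `fake_run` are hypothetical names, and the simulated `RuntimeError` stands in for a real CUDA out-of-memory error.

```python
from collections import deque


def relax_with_splitting(batch: list, run_batch) -> list:
    """Process `batch`, halving it whenever `run_batch` raises RuntimeError."""
    pending = deque([batch])
    done = []
    while pending:
        current = pending.popleft()
        error = None
        try:
            done.append(run_batch(current))
        except RuntimeError as err:
            error = err
        # Recover outside the except clause, mirroring the comment in the diff.
        if error is not None:
            if len(current) == 1:
                raise error  # a single system cannot be split any further
            mid = len(current) // 2
            # Same order as the diff: the second half ends up at the front.
            pending.appendleft(current[:mid])
            pending.appendleft(current[mid:])
    return done


def fake_run(batch: list, capacity: int = 2) -> list:
    """Pretend any batch larger than `capacity` runs out of memory."""
    if len(batch) > capacity:
        raise RuntimeError("out of memory (simulated)")
    return list(batch)


chunks = relax_with_splitting([1, 2, 3, 4, 5], fake_run)
flat = [item for chunk in chunks for item in chunk]
```

Because both halves are pushed with `appendleft`, the results do not come back in the original order. That is harmless as long as each system carries its own identifier, which is the role `sid` plays in the relaxed batch.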