From da75f0ddbf32e2515bbb950481dae0b52e465e25 Mon Sep 17 00:00:00 2001 From: anuroopsriram Date: Mon, 25 Nov 2024 15:25:33 -0800 Subject: [PATCH 1/7] Update ODAC model_checkpoints.md (#893) --- docs/core/model_checkpoints.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/core/model_checkpoints.md b/docs/core/model_checkpoints.md index 6cfba7055..342498e14 100644 --- a/docs/core/model_checkpoints.md +++ b/docs/core/model_checkpoints.md @@ -149,7 +149,7 @@ OC22 dataset or pretrained models, as well as the original paper for each model: | GemNet-OC-S2EF-ODAC | GemNet-OC | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/gemnet-oc.yml) | | eSCN-S2EF-ODAC | eSCN | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eSCN.yml) | | EquiformerV2-S2EF-ODAC | EquiformerV2 | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/eqv2_31M.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eqv2_31M.yml) | -| EquiformerV2-Large-S2EF-ODAC | EquiformerV2 (Large) | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Large.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eqv2_153M.yml) | +| EquiformerV2-Large-S2EF-ODAC | EquiformerV2 (Large) | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/Equiformer_V2_Large.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/s2ef/eqv2_153M.yml) | ## IS2RE Direct models @@ -157,7 +157,7 @@ OC22 dataset or pretrained models, as well as the original paper for each model: |-------------------------|--------------|--- | --- | | Gemnet-OC-IS2RE-ODAC | Gemnet-OC | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Gemnet-OC_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/gemnet-oc.yml) | | eSCN-IS2RE-ODAC | eSCN | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/eSCN_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/eSCN.yml) | -| EquiformerV2-IS2RE-ODAC | EquiformerV2 | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231018/Equiformer_V2_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/eqv2_31M.yml) | +| EquiformerV2-IS2RE-ODAC | EquiformerV2 | [checkpoint](https://dl.fbaipublicfiles.com/dac/checkpoints_20231116/Equiformer_V2_Direct.pt) | [config](https://github.com/FAIR-Chem/fairchem/tree/main/configs/odac/is2re/eqv2_31M.yml) | The models in the table above were trained to predict relaxed energy directly. Relaxed energies can also be predicted by running structural relaxations using the S2EF models from the previous section. From 21af12fec6730bec8a1713119179827e899cdc51 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Mon, 25 Nov 2024 16:54:44 -0800 Subject: [PATCH 2/7] update omat24 configs for new loss (#927) * update omat24 configs for new loss * fix typo --- configs/omat24/all/eqV2_153M.yml | 2 +- configs/omat24/all/eqV2_31M.yml | 2 +- configs/omat24/all/eqV2_86M.yml | 2 +- configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml | 2 +- configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml | 2 +- configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml | 2 +- configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml | 2 +- configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml | 2 +- configs/omat24/mptrj/eqV2_31M_mptrj.yml | 2 +- configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/configs/omat24/all/eqV2_153M.yml b/configs/omat24/all/eqV2_153M.yml index dffd4ec34..bf87a9e3e 100644 --- a/configs/omat24/all/eqV2_153M.yml +++ b/configs/omat24/all/eqV2_153M.yml @@ -43,7 +43,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/all/eqV2_31M.yml b/configs/omat24/all/eqV2_31M.yml index 95ff1f89f..902a58b31 100644 --- a/configs/omat24/all/eqV2_31M.yml +++ b/configs/omat24/all/eqV2_31M.yml @@ -44,7 +44,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/all/eqV2_86M.yml b/configs/omat24/all/eqV2_86M.yml index 30167e81a..154ed57cb 100644 --- a/configs/omat24/all/eqV2_86M.yml +++ b/configs/omat24/all/eqV2_86M.yml @@ -43,7 +43,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml b/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml index d02b83760..bd04e683d 100644 --- a/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml +++ b/configs/omat24/finetune/eqV2_153M_ft_salexmptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 10 diff --git a/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml b/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml index 146a15312..36e89c66b 100644 --- a/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml +++ b/configs/omat24/finetune/eqV2_31M_ft_salexmptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 10 diff --git a/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml b/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml index 8976ffa9a..8e230aa1a 100644 --- a/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml +++ b/configs/omat24/finetune/eqV2_86M_ft_salexmptrj.yml @@ -43,7 +43,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 10 diff --git a/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml b/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml index 050d5921d..435309a22 100644 --- a/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_153M_dens_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml b/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml index 818eaeb09..cf346824a 100644 --- a/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/mptrj/eqV2_31M_mptrj.yml b/configs/omat24/mptrj/eqV2_31M_mptrj.yml index c9ae7c84d..7f4c83cf6 100644 --- a/configs/omat24/mptrj/eqV2_31M_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_31M_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 5 - forces: fn: l2mae coefficient: 20 diff --git a/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml b/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml index f931ee78a..47f095885 100644 --- a/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_86M_dens_mptrj.yml @@ -45,7 +45,7 @@ outputs: loss_functions: - energy: fn: per_atom_mae - coefficient: 20 + coefficient: 2.5 - forces: fn: l2mae coefficient: 20 From 72614bf0098224052b43189260aeed22e1be552f Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Mon, 25 Nov 2024 17:22:10 -0800 Subject: [PATCH 3/7] Add wandb logger init to hydra runners (#894) * add wandb logger init to hydra runners * update to reading dict vars * update to reading dict vars * get rid of finally clause * move logger init * add deprecation comment * Revert "add deprecation comment" This reverts commit f9760e70e05b082e0d2b90344f196f889f9b1424. --- src/fairchem/core/_cli_hydra.py | 59 +++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py index 6ca8b7d6f..9279c8daf 100644 --- a/src/fairchem/core/_cli_hydra.py +++ b/src/fairchem/core/_cli_hydra.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING import hydra +from omegaconf import OmegaConf if TYPE_CHECKING: import argparse @@ -34,21 +35,40 @@ class Submitit(Checkpointable): - def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None: + def __call__(self, dict_config: DictConfig) -> None: self.config = dict_config - self.cli_args = cli_args # TODO: setup_imports is not needed if we stop instantiating models with Registry. setup_imports() setup_env_vars() - try: - distutils.setup(map_cli_args_to_dist_config(cli_args)) - self.runner: Runner = hydra.utils.instantiate(dict_config.runner) - self.runner.load_state() - self.runner.run() - finally: - distutils.cleanup() - - def checkpoint(self, *args, **kwargs): + distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args)) + self._init_logger() + runner: Runner = hydra.utils.instantiate(dict_config.runner) + runner.load_state() + runner.run() + distutils.cleanup() + + def _init_logger(self) -> None: + # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger + # don't start logger if in debug mode + if ( + "logger" in self.config + and distutils.is_master() + and not self.config.cli_args.debug + ): + # get a partial function from the config and instantiate wandb with it + logger_initializer = hydra.utils.instantiate(self.config.logger) + simple_config = OmegaConf.to_container( + self.config, resolve=True, throw_on_missing=True + ) + logger_initializer( + config=simple_config, + run_id=self.config.cli_args.timestamp_id, + run_name=self.config.cli_args.identifier, + log_dir=self.config.cli_args.logdir, + ) + + def checkpoint(self, *args, **kwargs) -> DelayedSubmission: + # TODO: this is yet to be tested properly logging.info("Submitit checkpointing callback is triggered") new_runner = Submitit() self.runner.save_state() @@ -56,7 +76,7 @@ def checkpoint(self, *args, **kwargs): return DelayedSubmission(new_runner, self.config, self.cli_args) -def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict: +def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict: return { "world_size": cli_args.num_nodes * cli_args.num_gpus, "distributed_backend": "gloo" if cli_args.cpu else "nccl", @@ -78,8 +98,8 @@ def get_hydra_config_from_yaml( return hydra.compose(config_name=config_name, overrides=overrides_args) -def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace): - Submitit()(config, cli_args) +def runner_wrapper(config: DictConfig): + Submitit()(config) # this is meant as a future replacement for the main entrypoint @@ -93,6 +113,11 @@ def main( cfg = get_hydra_config_from_yaml(args.config_yml, override_args) timestamp_id = get_timestamp_uid() log_dir = os.path.join(args.run_dir, timestamp_id, "logs") + # override timestamp id and logdir + args.timestamp_id = timestamp_id + args.logdir = log_dir + os.makedirs(log_dir) + OmegaConf.update(cfg, "cli_args", vars(args), force_add=True) if args.submit: # Run on cluster executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3) executor.update_parameters( @@ -107,7 +132,7 @@ def main( slurm_qos=args.slurm_qos, slurm_account=args.slurm_account, ) - job = executor.submit(runner_wrapper, cfg, args) + job = executor.submit(runner_wrapper, cfg) logger.info( f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}" ) @@ -131,8 +156,8 @@ def main( rdzv_backend="c10d", max_restarts=0, ) - elastic_launch(launch_config, runner_wrapper)(cfg, args) + elastic_launch(launch_config, runner_wrapper)(cfg) else: logger.info("Running in local mode without elastic launch") distutils.setup_env_local() - runner_wrapper(cfg, args) + runner_wrapper(cfg) From e11e78e87f791241b3759a7218a1b0e307a49662 Mon Sep 17 00:00:00 2001 From: Xiang <31351668+kyonofx@users.noreply.github.com> Date: Tue, 26 Nov 2024 15:13:10 -0800 Subject: [PATCH 4/7] EquiformerV2 + DeNS model and trainer (#880) * add density metrics * update trainer & loss * interleave atoms in loss * fix call to keys * add rmse to evaluation metrics * fix linting. * per_atom_loss fix * fix test * Equiformer DeNS model and trainer * fix linting. * lint * lint again * add type hints * empty cuda cache and remove db closing * type hints * add missing args to docstring * add return type hints * rename dens heads * move use_densoising to heads * abstract denoising targets * update omat24 dens config * fix imports * fix trainer --------- Co-authored-by: lbluque Co-authored-by: Luis Barroso-Luque Co-authored-by: Brandon --- configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml | 9 +- .../equiformer_v2/equiformer_v2_dens.py | 586 ++++++++++++ .../equiformer_v2/trainers/dens_trainer.py | 854 ++++++++++++++++++ src/fairchem/core/trainers/ocp_trainer.py | 14 +- 4 files changed, 1456 insertions(+), 7 deletions(-) create mode 100644 src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py create mode 100644 src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py diff --git a/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml b/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml index cf346824a..78ccb4b92 100644 --- a/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml +++ b/configs/omat24/mptrj/eqV2_31M_dens_mptrj.yml @@ -148,17 +148,16 @@ model: use_force_encoding: True use_noise_schedule_sigma_encoding: False - use_denoising_energy: True - use_denoising_stress: False - heads: energy: - module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSEnergyHead + module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSScalarHead + use_denoising: True forces: - module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSForceHead + module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSVectorHead stress: module: fairchem.core.models.equiformer_v2.equiformer_v2_dens.DeNSRank2Head output_name: stress use_source_target_embedding: True decompose: True + use_denoising: False diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py new file mode 100644 index 000000000..3ad881d63 --- /dev/null +++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2_dens.py @@ -0,0 +1,586 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import math +from functools import partial +from typing import TYPE_CHECKING, Literal + +import torch + +from fairchem.core.common import gp_utils +from fairchem.core.common.registry import registry +from fairchem.core.common.utils import conditional_grad + +try: + from e3nn import o3 +except ImportError: + import contextlib + + contextlib.suppress(ImportError) + +from fairchem.core.models.base import GraphData, HeadInterface +from fairchem.core.models.equiformer_v2.equiformer_v2 import ( + EquiformerV2Backbone, + eqv2_init_weights, +) +from fairchem.core.models.equiformer_v2.heads.rank2 import ( + Rank2SymmetricTensorHead, +) +from fairchem.core.models.equiformer_v2.so3 import SO3_Embedding, SO3_LinearV2 +from fairchem.core.models.equiformer_v2.transformer_block import ( + FeedForwardNetwork, + SO2EquivariantGraphAttention, +) + +if TYPE_CHECKING: + from torch_geometric.data.batch import Batch + + from fairchem.core.models.base import BackboneInterface + + +@registry.register_model("equiformer_v2_dens_backbone") +class EqV2DeNSBackbone(EquiformerV2Backbone): + """ + DeNS extra Args: + use_force_encoding (bool): For ablation study, whether to encode forces during denoising positions. Default: True. + use_noise_schedule_sigma_encoding (bool): For ablation study, whether to encode the sigma (sampled std of Gaussian noises) during + denoising positions when `fixed_noise_std` = False in config files. Default: False. + """ + + def __init__( + self, + use_pbc: bool = True, + use_pbc_single: bool = False, + regress_forces: bool = True, + otf_graph: bool = True, + max_neighbors: int = 500, + max_radius: float = 5.0, + max_num_elements: int = 90, + num_layers: int = 12, + sphere_channels: int = 128, + attn_hidden_channels: int = 128, + num_heads: int = 8, + attn_alpha_channels: int = 32, + attn_value_channels: int = 16, + ffn_hidden_channels: int = 512, + norm_type: str = "rms_norm_sh", + lmax_list: list[int] | None = None, + mmax_list: list[int] | None = None, + grid_resolution: int | None = None, + num_sphere_samples: int = 128, + edge_channels: int = 128, + use_atom_edge_embedding: bool = True, + share_atom_edge_embedding: bool = False, + use_m_share_rad: bool = False, + distance_function: str = "gaussian", + num_distance_basis: int = 512, + attn_activation: str = "scaled_silu", + use_s2_act_attn: bool = False, + use_attn_renorm: bool = True, + ffn_activation: str = "scaled_silu", + use_gate_act: bool = False, + use_grid_mlp: bool = False, + use_sep_s2_act: bool = True, + alpha_drop: float = 0.1, + drop_path_rate: float = 0.05, + proj_drop: float = 0.0, + weight_init: str = "normal", + enforce_max_neighbors_strictly: bool = True, + avg_num_nodes: float | None = None, + avg_degree: float | None = None, + use_energy_lin_ref: bool | None = False, + load_energy_lin_ref: bool | None = False, + activation_checkpoint: bool | None = False, + use_force_encoding=True, + use_noise_schedule_sigma_encoding: bool = False, + ): + if mmax_list is None: + mmax_list = [2] + if lmax_list is None: + lmax_list = [6] + super().__init__( + use_pbc, + use_pbc_single, + regress_forces, + otf_graph, + max_neighbors, + max_radius, + max_num_elements, + num_layers, + sphere_channels, + attn_hidden_channels, + num_heads, + attn_alpha_channels, + attn_value_channels, + ffn_hidden_channels, + norm_type, + lmax_list, + mmax_list, + grid_resolution, + num_sphere_samples, + edge_channels, + use_atom_edge_embedding, + share_atom_edge_embedding, + use_m_share_rad, + distance_function, + num_distance_basis, + attn_activation, + use_s2_act_attn, + use_attn_renorm, + ffn_activation, + use_gate_act, + use_grid_mlp, + use_sep_s2_act, + alpha_drop, + drop_path_rate, + proj_drop, + weight_init, + enforce_max_neighbors_strictly, + avg_num_nodes, + avg_degree, + use_energy_lin_ref, + load_energy_lin_ref, + activation_checkpoint, + ) + + # for denoising position + self.use_force_encoding = use_force_encoding + self.use_noise_schedule_sigma_encoding = use_noise_schedule_sigma_encoding + + # for denoising position, encode node-wise forces as node features + self.irreps_sh = o3.Irreps.spherical_harmonics(lmax=max(self.lmax_list), p=1) + self.force_embedding = SO3_LinearV2( + in_features=1, out_features=self.sphere_channels, lmax=max(self.lmax_list) + ) + + if self.use_noise_schedule_sigma_encoding: + self.noise_schedule_sigma_embedding = torch.nn.Linear( + in_features=1, out_features=self.sphere_channels + ) + + self.apply(partial(eqv2_init_weights, weight_init=self.weight_init)) + + @conditional_grad(torch.enable_grad()) + def forward(self, data) -> dict[str, torch.Tensor]: + self.batch_size = len(data.natoms) + self.dtype = data.pos.dtype + self.device = data.pos.device + num_atoms = len(data.atomic_numbers) + atomic_numbers = data.atomic_numbers.long() + graph = self.generate_graph( + data, + enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + ) + + data_batch = data.batch + if gp_utils.initialized(): + ( + atomic_numbers, + data_batch, + node_offset, + edge_index, + edge_distance, + edge_distance_vec, + ) = self._init_gp_partitions( + graph.atomic_numbers_full, + graph.batch_full, + graph.edge_index, + graph.edge_distance, + graph.edge_distance_vec, + ) + graph.node_offset = node_offset + graph.edge_index = edge_index + graph.edge_distance = edge_distance + graph.edge_distance_vec = edge_distance_vec + + ############################################################### + # Entering Graph Parallel Region + # after this point, if using gp, then node, edge tensors are split + # across the graph parallel ranks, some full tensors such as + # atomic_numbers_full are required because we need to index into the + # full graph when computing edge embeddings or reducing nodes from neighbors + # + # all tensors that do not have the suffix "_full" refer to the partial tensors. + # if not using gp, the full values are equal to the partial values + # ie: atomic_numbers_full == atomic_numbers + ############################################################### + + ############################################################### + # Initialize data structures + ############################################################### + + # Compute 3x3 rotation matrix per edge + edge_rot_mat = self._init_edge_rot_mat( + data, graph.edge_index, graph.edge_distance_vec + ) + + # Initialize the WignerD matrices and other values for spherical harmonic calculations + for i in range(self.num_resolutions): + self.SO3_rotation[i].set_wigner(edge_rot_mat) + + ############################################################### + # Initialize node embeddings + ############################################################### + + # Init per node representations using an atomic number based embedding + x = SO3_Embedding( + len(atomic_numbers), + self.lmax_list, + self.sphere_channels, + self.device, + self.dtype, + ) + + offset_res = 0 + offset = 0 + # Initialize the l = 0, m = 0 coefficients for each resolution + for i in range(self.num_resolutions): + if self.num_resolutions == 1: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers) + else: + x.embedding[:, offset_res, :] = self.sphere_embedding(atomic_numbers)[ + :, offset : offset + self.sphere_channels + ] + offset = offset + self.sphere_channels + offset_res = offset_res + int((self.lmax_list[i] + 1) ** 2) + + ################## + ### DeNS Start ### + ################## + + # Node-wise force encoding during denoising positions + force_embedding = SO3_Embedding( + num_atoms, self.lmax_list, 1, self.device, self.dtype + ) + if hasattr(data, "denoising_pos_forward") and data.denoising_pos_forward: + assert hasattr(data, "forces") + force_data = data.forces + force_sh = o3.spherical_harmonics( + l=self.irreps_sh, + x=force_data, + normalize=True, + normalization="component", + ) + force_sh = force_sh.view(num_atoms, (max(self.lmax_list) + 1) ** 2, 1) + force_norm = force_data.norm(dim=-1, keepdim=True) + if hasattr(data, "noise_mask"): + noise_mask_tensor = data.noise_mask.view(-1, 1, 1) + force_sh = force_sh * noise_mask_tensor + else: + force_sh = torch.zeros( + (num_atoms, (max(self.lmax_list) + 1) ** 2, 1), + dtype=data.pos.dtype, + device=data.pos.device, + ) + force_norm = torch.zeros( + (num_atoms, 1), dtype=data.pos.dtype, device=data.pos.device + ) + + if not self.use_force_encoding: + # for ablation study, we enforce the force encoding to be zero. + force_sh = torch.zeros( + (num_atoms, (max(self.lmax_list) + 1) ** 2, 1), + dtype=data.pos.dtype, + device=data.pos.device, + ) + force_norm = torch.zeros( + (num_atoms, 1), dtype=data.pos.dtype, device=data.pos.device + ) + + force_norm = force_norm.view(-1, 1, 1) + force_norm = force_norm / math.sqrt( + 3.0 + ) # since we use `component` normalization + force_embedding.embedding = force_sh * force_norm + + force_embedding = self.force_embedding(force_embedding) + x.embedding = x.embedding + force_embedding.embedding + + # noise schedule sigma encoding + if self.use_noise_schedule_sigma_encoding: + if hasattr(data, "denoising_pos_forward") and data.denoising_pos_forward: + assert hasattr(data, "sigmas") + sigmas = data.sigmas + else: + sigmas = torch.zeros( + (num_atoms, 1), dtype=data.pos.dtype, device=data.pos.device + ) + noise_schedule_sigma_enbedding = self.noise_schedule_sigma_embedding(sigmas) + x.embedding[:, 0, :] = x.embedding[:, 0, :] + noise_schedule_sigma_enbedding + + ################## + ### DeNS End ### + ################## + + # Edge encoding (distance and atom edge) + graph.edge_distance = self.distance_expansion(graph.edge_distance) + if self.share_atom_edge_embedding and self.use_atom_edge_embedding: + source_element = graph.atomic_numbers_full[ + graph.edge_index[0] + ] # Source atom atomic number + target_element = graph.atomic_numbers_full[ + graph.edge_index[1] + ] # Target atom atomic number + source_embedding = self.source_embedding(source_element) + target_embedding = self.target_embedding(target_element) + graph.edge_distance = torch.cat( + (graph.edge_distance, source_embedding, target_embedding), dim=1 + ) + + # Edge-degree embedding + edge_degree = self.edge_degree_embedding( + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + len(atomic_numbers), + graph.node_offset, + ) + x.embedding = x.embedding + edge_degree.embedding + + ############################################################### + # Update spherical node embeddings + ############################################################### + + for i in range(self.num_layers): + if self.activation_checkpoint: + x = torch.utils.checkpoint.checkpoint( + self.blocks[i], + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + data_batch, # for GraphDropPath + graph.node_offset, + use_reentrant=not self.training, + ) + else: + x = self.blocks[i]( + x, # SO3_Embedding + graph.atomic_numbers_full, + graph.edge_distance, + graph.edge_index, + batch=data_batch, # for GraphDropPath + node_offset=graph.node_offset, + ) + + # Final layer norm + x.embedding = self.norm(x.embedding) + + return {"node_embedding": x, "graph": graph} + + +@registry.register_model("eqV2_DeNS_scalar_head") +class DeNSScalarHead(torch.nn.Module, HeadInterface): + def __init__( + self, + backbone: BackboneInterface, + output_name: str = "energy", + reduce: Literal["sum", "mean"] = "sum", + use_denoising: bool = True, + ): + """ + Args: + backbone: Model backbone + output_name: property output name + reduce: reduction, mean or sum. Use mean for intensive properties and sum for extensive ones. + use_denoising: For ablation study, whether to predict the energy of the original structure given + a corrupted structure. If `False`, we zero out the energy prediction. Default: True. + """ + super().__init__() + self.reduce = reduce + self.avg_num_nodes = backbone.avg_num_nodes + self.scalar_block = FeedForwardNetwork( + backbone.sphere_channels, + backbone.ffn_hidden_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_grid, + backbone.ffn_activation, + backbone.use_gate_act, + backbone.use_grid_mlp, + backbone.use_sep_s2_act, + ) + self.output_name = output_name + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) + self.use_denoising = use_denoising + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor | GraphData] + ) -> dict[str, torch.Tensor]: + node_out = self.scalar_block(emb["node_embedding"]) + node_out = node_out.embedding.narrow(1, 0, 1) + if gp_utils.initialized(): + node_out = gp_utils.gather_from_model_parallel_region(node_out, dim=0) + output_scalar = torch.zeros( + len(data.natoms), + device=node_out.device, + dtype=node_out.dtype, + ) + + output_scalar.index_add_(0, data.batch, node_out.view(-1)) + + if ( + hasattr(data, "denoising_pos_forward") + and data.denoising_pos_forward + and not self.use_denoising + ): + output_scalar = output_scalar * 0.0 + + if self.reduce == "sum": + return {self.output_name: output_scalar / self.avg_num_nodes} + elif self.reduce == "mean": + return {self.output_name: output_scalar / data.natoms} + else: + raise ValueError( + f"reduce can only be sum or mean, user provided: {self.reduce}" + ) + + +@registry.register_model("eqV2_DeNS_vector_head") +class DeNSVectorHead(torch.nn.Module, HeadInterface): + def __init__(self, backbone: BackboneInterface, output_name: str = "forces"): + super().__init__() + + self.output_name = output_name + self.activation_checkpoint = backbone.activation_checkpoint + + self.vector_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + + self.denoising_pos_block = SO2EquivariantGraphAttention( + backbone.sphere_channels, + backbone.attn_hidden_channels, + backbone.num_heads, + backbone.attn_alpha_channels, + backbone.attn_value_channels, + 1, + backbone.lmax_list, + backbone.mmax_list, + backbone.SO3_rotation, + backbone.mappingReduced, + backbone.SO3_grid, + backbone.max_num_elements, + backbone.edge_channels_list, + backbone.block_use_atom_edge_embedding, + backbone.use_m_share_rad, + backbone.attn_activation, + backbone.use_s2_act_attn, + backbone.use_attn_renorm, + backbone.use_gate_act, + backbone.use_sep_s2_act, + alpha_drop=0.0, + ) + self.apply(partial(eqv2_init_weights, weight_init=backbone.weight_init)) + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + if self.activation_checkpoint: + output_vector = torch.utils.checkpoint.checkpoint( + self.vector_block, + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + emb["graph"].node_offset, + use_reentrant=not self.training, + ) + denoising_pos_vec = torch.utils.checkpoint.checkpoint( + self.denoising_pos_block, + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + emb["graph"].node_offset, + use_reentrant=not self.training, + ) + else: + output_vector = self.vector_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + denoising_pos_vec = self.denoising_pos_block( + emb["node_embedding"], + emb["graph"].atomic_numbers_full, + emb["graph"].edge_distance, + emb["graph"].edge_index, + node_offset=emb["graph"].node_offset, + ) + output_vector = output_vector.embedding.narrow(1, 1, 3) + output_vector = output_vector.view(-1, 3).contiguous() + denoising_pos_vec = denoising_pos_vec.embedding.narrow(1, 1, 3) + denoising_pos_vec = denoising_pos_vec.view(-1, 3) + if gp_utils.initialized(): + output_vector = gp_utils.gather_from_model_parallel_region( + output_vector, dim=0 + ) + denoising_pos_vec = gp_utils.gather_from_model_parallel_region( + denoising_pos_vec, dim=0 + ) + + if hasattr(data, "denoising_pos_forward") and data.denoising_pos_forward: + if hasattr(data, "noise_mask"): + noise_mask_tensor = data.noise_mask.view(-1, 1) + output_vector = ( + denoising_pos_vec * noise_mask_tensor + + output_vector * (~noise_mask_tensor) + ) + else: + output_vector = denoising_pos_vec + 0 * output_vector + else: + output_vector = 0 * denoising_pos_vec + output_vector + + return {self.output_name: output_vector} + + +@registry.register_model("dens_rank2_symmetric_head") +class DeNSRank2Head(Rank2SymmetricTensorHead): + def __init__( + self, backbone: BackboneInterface, *args, use_denoising: bool = True, **kwargs + ): + super().__init__(backbone, *args, **kwargs) + self.use_denoising = use_denoising + + def forward( + self, data: Batch, emb: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + output = super().forward(data, emb) + if ( + hasattr(data, "denoising_pos_forward") + and data.denoising_pos_forward + and not self.use_denoising + ): + for k in output: + output[k] = output[k] * 0.0 + return output diff --git a/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py b/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py new file mode 100644 index 000000000..11735d7bb --- /dev/null +++ b/src/fairchem/core/models/equiformer_v2/trainers/dens_trainer.py @@ -0,0 +1,854 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from dataclasses import dataclass +from functools import cached_property +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +import torch_geometric +from tqdm import tqdm + +from fairchem.core.common import distutils +from fairchem.core.common.registry import registry +from fairchem.core.modules.evaluator import mae +from fairchem.core.modules.normalization.normalizer import Normalizer +from fairchem.core.modules.scaling.util import ensure_fitted + +from .forces_trainer import EquiformerV2ForcesTrainer + +if TYPE_CHECKING: + from fairchem.core.modules.evaluator import Evaluator + + +@dataclass +class DenoisingPosParams: + prob: float = 0.0 + fixed_noise_std: bool = False + std: float = None + num_steps: int = None + std_low: float = None + std_high: float = None + corrupt_ratio: float = None + all_atoms: bool = False + denoising_pos_coefficient: float = None + + +def add_gaussian_noise_to_position(batch, std, corrupt_ratio=None, all_atoms=False): + """ + 1. Update `pos` in `batch`. + 2. Add `noise_vec` to `batch`, which will serve as the target for denoising positions. + 3. Add `denoising_pos_forward` to switch to denoising mode during training. + 4. Add `noise_mask` for partially corrupted structures when `corrupt_ratio` is not None. + 5. If `all_atoms` == True, we add noise to all atoms including fixed ones. + 6. Check whether `batch` has `md`. We do not add noise to structures from MD split. + """ + noise_vec = torch.zeros_like(batch.pos) + noise_vec = noise_vec.normal_(mean=0.0, std=std) + + if corrupt_ratio is not None: + noise_mask = torch.rand( + (batch.pos.shape[0]), + dtype=batch.pos.dtype, + device=batch.pos.device, + ) + noise_mask = noise_mask < corrupt_ratio + noise_vec[(~noise_mask)] *= 0 + batch.noise_mask = noise_mask + + # Not add noise to structures from MD split + if hasattr(batch, "md"): + batch_index = batch.batch + md_index = batch.md.bool() + md_index = md_index[batch_index] + noise_mask = ~md_index + noise_vec[(~noise_mask)] *= 0 + if hasattr(batch, "noise_mask"): + batch.noise_mask = batch.noise_mask * noise_mask + else: + batch.noise_mask = noise_mask + + pos = batch.pos + new_pos = pos + noise_vec + if all_atoms: + batch.pos = new_pos + else: + free_mask = batch.fixed == 0.0 + batch.pos[free_mask] = new_pos[free_mask] + + batch.noise_vec = noise_vec + batch.denoising_pos_forward = True + + return batch + + +def add_gaussian_noise_schedule_to_position( + batch, std_low, std_high, num_steps, corrupt_ratio=None, all_atoms=False +): + """ + 1. Similar to above, update positions in batch with gaussian noise, but + additionally, also save the sigmas the noise vectors are sampled from. + 2. Add `noise_mask` for partially corrupted structures when `corrupt_ratio` + is not None. + 3. If `all_atoms` == True, we add noise to all atoms including fixed ones. + 4. Check whether `batch` has `md`. We do not add noise to structures from MD split. + """ + sigmas = torch.tensor( + np.exp(np.linspace(np.log(std_low), np.log(std_high), num_steps)), + dtype=torch.float32, + ) + # select a sigma for each structure, and project it all atoms in the structure. + ts = torch.randint(0, num_steps, size=(batch.natoms.size(0),)) + batch.sigmas = sigmas[ts][batch.batch][:, None] # (natoms, 1) + noise_vec = torch.zeros_like(batch.pos) + noise_vec = noise_vec.normal_() * batch.sigmas + + if corrupt_ratio is not None: + noise_mask = torch.rand( + (batch.pos.shape[0]), + dtype=batch.pos.dtype, + device=batch.pos.device, + ) + noise_mask = noise_mask < corrupt_ratio + # noise_vec[(~noise_mask)] *= 0 + batch.noise_mask = noise_mask + + # Not add noise to structures from MD split + if hasattr(batch, "md"): + batch_index = batch.batch + md_index = batch.md.bool() + md_index = md_index[batch_index] + noise_mask = ~md_index + # noise_vec[(~noise_mask)] *= 0 + if hasattr(batch, "noise_mask"): + batch.noise_mask = batch.noise_mask * noise_mask + else: + batch.noise_mask = noise_mask + + if hasattr(batch, "noise_mask"): + noise_vec[(~batch.noise_mask)] *= 0 + + # only add noise to free atoms + pos = batch.pos + new_pos = pos + noise_vec + if all_atoms: + batch.pos = new_pos + else: + free_mask = batch.fixed == 0.0 + batch.pos[free_mask] = new_pos[free_mask] + + batch.noise_vec = noise_vec + batch.denoising_pos_forward = True + + return batch + + +def denoising_pos_eval( + evaluator: Evaluator, + prediction: dict[str, torch.Tensor], + target: dict[str, torch.Tensor], + denoising_targets: tuple[str], + prev_metrics: dict[str, torch.Tensor] | None = None, + denoising_pos_forward: bool = False, +): + """ + 1. Overwrite the original Evaluator.eval() here: https://github.com/Open-Catalyst-Project/ocp/blob/5a7738f9aa80b1a9a7e0ca15e33938b4d2557edd/ocpmodels/modules/evaluator.py#L69-L81 + 2. This is to make sure we separate forces MAE and denoising positions MAE. + """ + + if not denoising_pos_forward: + return evaluator.eval(prediction, target, prev_metrics) + + metrics = prev_metrics + for target_name in denoising_targets: + res = mae(prediction, target, target_name) + metrics = evaluator.update(f"denoising_{target_name}_mae", res, metrics) + + if target.get("noise_mask") is None: + # Only update`denoising_pos_mae` during denoising positions if not using partially corrupted structures + res = mae(prediction, target, "forces") + metrics = evaluator.update("denoising_pos_mae", res, metrics) + else: # Update `denoising_pos_mae` and `denoising_force_mae` if using partially corrupted structures + # separate S2EF and denoising positions results based on `noise_mask` + target_tensor = target["forces"] + prediction_tensor = prediction["forces"] + noise_mask = target["noise_mask"] + s2ef_index = torch.where(noise_mask == 0) + s2ef_prediction = {"forces": prediction_tensor[s2ef_index]} + s2ef_target = {"forces": target_tensor[s2ef_index]} + res = mae(s2ef_prediction, s2ef_target, "forces") + if res["numel"] != 0: + metrics = evaluator.update("denoising_force_mae", res, metrics) + denoising_pos_index = torch.where(noise_mask == 1) + denoising_pos_prediction = {"forces": prediction_tensor[denoising_pos_index]} + denoising_pos_target = {"forces": target_tensor[denoising_pos_index]} + res = mae(denoising_pos_prediction, denoising_pos_target, "forces") + if res["numel"] != 0: + metrics = evaluator.update("denoising_pos_mae", res, metrics) + return metrics + + +def compute_atomwise_denoising_pos_and_force_hybrid_loss( + pred, target, noise_mask, force_mult, denoising_pos_mult, mask=None +): + loss = torch.norm(pred - target, p=2, dim=-1, keepdim=True) + force_index = torch.where(noise_mask == 0) + denoising_pos_index = torch.where(noise_mask == 1) + mult_tensor = torch.ones_like(loss) + mult_tensor[force_index] *= force_mult + mult_tensor[denoising_pos_index] *= denoising_pos_mult + loss = loss * mult_tensor + if mask is not None: + loss = loss[mask] + return torch.mean(loss) + + +@registry.register_trainer("equiformerv2_dens") +class DenoisingForcesTrainer(EquiformerV2ForcesTrainer): + """ + 1. We add a denoising objective to the original S2EF task. + 2. The denoising objective is that we take as input + atom types, node-wise forces and 3D coordinates perturbed with Gaussian noises and then + output energy of the original structure (3D coordinates without any perturbation) and + the node-wise noises added to the original structure. + 3. This should make models leverage more from training data and enable data augmentation for + the S2EF task. + 4. We should only modify the training part. + 5. For normalizing the outputs of noise prediction, if we use `fixed_noise_std = True`, we use + `std` for the normalization factor. Otherwise, we use `std_high` when `fixed_noise_std = False`. + + Args: + task (dict): Task configuration. + model (dict): Model configuration. + outputs (dict): Dictionary of model output configuration. + dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset. + optimizer (dict): Optimizer configuration. + loss_functions (dict): Loss function configuration. + evaluation_metrics (dict): Evaluation metrics configuration. + identifier (str): Experiment identifier that is appended to log directory. + run_dir (str, optional): Path to the run directory where logs are to be saved. + (default: :obj:`None`) + timestamp_id (str, optional): timestamp identifier. + run_dir (str, optional): Run directory used to save checkpoints and results. + is_debug (bool, optional): Run in debug mode. + (default: :obj:`False`) + print_every (int, optional): Frequency of printing logs. + (default: :obj:`100`) + seed (int, optional): Random number seed. + (default: :obj:`None`) + logger (str, optional): Type of logger to be used. + (default: :obj:`wandb`) + local_rank (int, optional): Local rank of the process, only applicable for distributed training. + (default: :obj:`0`) + amp (bool, optional): Run using automatic mixed precision. + (default: :obj:`False`) + cpu (bool): If True will run on CPU. Default is False, will attempt to use cuda. + name (str): Trainer name. + slurm (dict): Slurm configuration. Currently just for keeping track. + (default: :obj:`{}`) + gp_gpus (int, optional): Number of graph parallel GPUs. + inference_only (bool): If true trainer will be loaded for inference only. + (ie datasets, optimizer, schedular, etc, will not be instantiated) + """ + + def __init__( + self, + task: dict[str, str | Any], + model: dict[str, Any], + outputs: dict[str, str | int], + dataset: dict[str, str | float], + optimizer: dict[str, str | float], + loss_functions: dict[str, str | float], + evaluation_metrics: dict[str, str], + identifier: str, + # TODO: dealing with local rank is dangerous + # T201111838 remove this and use CUDA_VISIBILE_DEVICES instead so trainers don't need to know about which devie to use + local_rank: int, + timestamp_id: str | None = None, + run_dir: str | None = None, + is_debug: bool = False, + print_every: int = 100, + seed: int | None = None, + logger: str = "wandb", + amp: bool = False, + cpu: bool = False, + name: str = "ocp", + slurm: dict | None = None, + gp_gpus: int | None = None, + inference_only: bool = False, + ): + if slurm is None: + slurm = {} + super().__init__( + task=task, + model=model, + outputs=outputs, + dataset=dataset, + optimizer=optimizer, + loss_functions=loss_functions, + evaluation_metrics=evaluation_metrics, + identifier=identifier, + timestamp_id=timestamp_id, + run_dir=run_dir, + is_debug=is_debug, + print_every=print_every, + seed=seed, + logger=logger, + local_rank=local_rank, + amp=amp, + cpu=cpu, + slurm=slurm, + name=name, + gp_gpus=gp_gpus, + inference_only=inference_only, + ) + + # for denoising positions + self.use_denoising_pos = self.config["optim"]["use_denoising_pos"] + self.denoising_pos_params = DenoisingPosParams( + **self.config["optim"]["denoising_pos_params"] + ) + self.denoising_pos_params.denoising_pos_coefficient = self.config["optim"][ + "denoising_pos_coefficient" + ] + self.normalizers["denoising_pos_target"] = Normalizer( + mean=0.0, + rmsd=( + self.denoising_pos_params.std + if self.denoising_pos_params.fixed_noise_std + else self.denoising_pos_params.std_high + ), + ) + self.normalizers["denoising_pos_target"].to(self.device) + + @cached_property + def denoising_targets(self): + return tuple( + head.output_name + for head in self._unwrapped_model.output_heads.values() + if getattr(head, "use_denoising", False) + ) + + def train(self, disable_eval_tqdm=False): + ensure_fitted(self._unwrapped_model, warn=True) + + eval_every = self.config["optim"].get("eval_every", len(self.train_loader)) + checkpoint_every = self.config["optim"].get("checkpoint_every", eval_every) + primary_metric = self.evaluation_metrics.get( + "primary_metric", self.evaluator.task_primary_metric[self.name] + ) + if not hasattr(self, "primary_metric") or self.primary_metric != primary_metric: + self.best_val_metric = 1e9 if "mae" in primary_metric else -1.0 + else: + primary_metric = self.primary_metric + self.metrics = {} + + # Calculate start_epoch from step instead of loading the epoch number + # to prevent inconsistencies due to different batch size in checkpoint. + start_epoch = self.step // len(self.train_loader) + + for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]): + skip_steps = self.step % len(self.train_loader) + self.train_sampler.set_epoch_and_start_iteration(epoch_int, skip_steps) + train_loader_iter = iter(self.train_loader) + + for i in range(skip_steps, len(self.train_loader)): + self.epoch = epoch_int + (i + 1) / len(self.train_loader) + self.step = epoch_int * len(self.train_loader) + i + 1 + self.model.train() + + # Get a batch. + batch = next(train_loader_iter) + + # for denoising positions + if ( + self.use_denoising_pos + and np.random.rand() < self.denoising_pos_params.prob + ): + if self.denoising_pos_params.fixed_noise_std: + batch = add_gaussian_noise_to_position( + batch, + std=self.denoising_pos_params.std, + corrupt_ratio=self.denoising_pos_params.corrupt_ratio, + all_atoms=self.denoising_pos_params.all_atoms, + ) + else: + batch = add_gaussian_noise_schedule_to_position( + batch, + std_low=self.denoising_pos_params.std_low, + std_high=self.denoising_pos_params.std_high, + num_steps=self.denoising_pos_params.num_steps, + corrupt_ratio=self.denoising_pos_params.corrupt_ratio, + all_atoms=self.denoising_pos_params.all_atoms, + ) + + # Forward, loss, backward. #TODO update this with new signatures + with torch.cuda.amp.autocast(enabled=self.scaler is not None): + out = self._forward(batch) + loss = self._compute_loss(out, batch) + + # Compute metrics. + self.metrics = self._compute_metrics( + out, + batch, + self.evaluator, + self.metrics, + ) + self.metrics = self.evaluator.update("loss", loss.item(), self.metrics) + + loss = self.scaler.scale(loss) if self.scaler else loss + self._backward(loss) + + # Log metrics. + log_dict = {k: self.metrics[k]["metric"] for k in self.metrics} + log_dict.update( + { + "lr": self.scheduler.get_lr(), + "epoch": self.epoch, + "step": self.step, + } + ) + if ( + self.step % self.config["cmd"]["print_every"] == 0 + or i == 0 + or i == (len(self.train_loader) - 1) + ) and distutils.is_master(): + log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()] + logging.info(", ".join(log_str)) + self.metrics = {} + + if self.logger is not None: + self.logger.log( + log_dict, + step=self.step, + split="train", + ) + + if checkpoint_every != -1 and self.step % checkpoint_every == 0: + self.save(checkpoint_file="checkpoint.pt", training_state=True) + + # Evaluate on val set every `eval_every` iterations. + if self.step % eval_every == 0 or i == (len(self.train_loader) - 1): + if self.val_loader is not None: + if i == (len(self.train_loader) - 1): + self.save( + checkpoint_file="checkpoint.pt", + training_state=True, + ) + + val_metrics = self.validate( + split="val", disable_tqdm=disable_eval_tqdm + ) + self.update_best( + primary_metric, + val_metrics, + disable_eval_tqdm=disable_eval_tqdm, + ) + + if self.config["task"].get("eval_relaxations", False): + if "relax_dataset" not in self.config["task"]: + logging.warning( + "Cannot evaluate relaxations, relax_dataset not specified" + ) + else: + self.run_relaxations() + + if self.scheduler.scheduler_type == "ReduceLROnPlateau": + if self.step % eval_every == 0: + self.scheduler.step( + metrics=val_metrics[primary_metric]["metric"], + ) + else: + self.scheduler.step() + + torch.cuda.empty_cache() + + if checkpoint_every == -1: + self.save(checkpoint_file="checkpoint.pt", training_state=True) + + def _compute_loss(self, out, batch): + batch_size = batch.natoms.numel() + fixed = batch.fixed + mask = fixed == 0 + + loss = [] + for loss_fn in self.loss_functions: + target_name, loss_info = loss_fn + + if target_name == "forces" and batch.get("denoising_pos_forward", False): + denoising_pos_target = batch.noise_vec + if self.normalizers.get("denoising_pos_target", False): + denoising_pos_target = self.normalizers[ + "denoising_pos_target" + ].norm(denoising_pos_target) + + if hasattr(batch, "noise_mask"): + # for partially corrupted structures + target = batch.forces + if self.normalizers.get("forces", False): + target = self.normalizers["forces"].norm(target) + noise_mask = batch.noise_mask.view(-1, 1) + target = denoising_pos_target * noise_mask + target * (~noise_mask) + else: + target = denoising_pos_target + + pred = out["forces"] + natoms = batch.natoms + natoms = torch.repeat_interleave(natoms, natoms) + + force_mult = loss_info["coefficient"] + denoising_pos_mult = self.denoising_pos_params.denoising_pos_coefficient + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["train_on_free_atoms"] + ): + # If `all_atoms` == True when training on only free atoms, + # we also add noise to and denoise fixed atoms. + if self.denoising_pos_params.all_atoms: + if hasattr(batch, "noise_mask"): + mask = mask.view(-1, 1) | noise_mask + else: + mask = torch.ones_like( + mask, dtype=torch.bool, device=mask.device + ).view(-1, 1) + + if hasattr(batch, "noise_mask"): + # for partially corrupted structures + loss.append( + compute_atomwise_denoising_pos_and_force_hybrid_loss( + pred=pred, + target=target, + noise_mask=noise_mask, + force_mult=force_mult, + denoising_pos_mult=denoising_pos_mult, + mask=mask, + ) + ) + else: + target = target[mask] + pred = pred[mask] + natoms = natoms[mask] + + loss.append( + denoising_pos_mult + * loss_info["fn"]( + pred, + target, + natoms=natoms, + ) + ) + else: + if hasattr(batch, "noise_mask"): + # for partially corrupted structures + loss.append( + compute_atomwise_denoising_pos_and_force_hybrid_loss( + pred=pred, + target=target, + noise_mask=noise_mask, + force_mult=force_mult, + denoising_pos_mult=denoising_pos_mult, + mask=None, + ) + ) + else: + loss.append( + denoising_pos_mult + * loss_info["fn"]( + pred, + target, + natoms=natoms, + ) + ) + + else: + target = batch[target_name] + pred = out[target_name] + natoms = batch.natoms + natoms = torch.repeat_interleave(natoms, natoms) + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["train_on_free_atoms"] + ): + target = target[mask] + pred = pred[mask] + natoms = natoms[mask] + + num_atoms_in_batch = natoms.numel() + + ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1 + if self.output_targets[target_name]["level"] == "atom": + target = target.view(num_atoms_in_batch, -1) + else: + target = target.view(batch_size, -1) + + # to keep the loss coefficient weights balanced we remove linear references + # subtract element references from target data + if target_name in self.elementrefs: + target = self.elementrefs[target_name].dereference(target, batch) + # normalize the targets data + if target_name in self.normalizers: + target = self.normalizers[target_name].norm(target) + + mult = loss_info["coefficient"] + + loss.append( + mult + * loss_info["fn"]( + pred, + target, + natoms=batch.natoms, + ) + ) + + # Sanity check to make sure the compute graph is correct. + for lc in loss: + assert hasattr(lc, "grad_fn") + + return sum(loss) + + def _compute_metrics(self, out, batch, evaluator, metrics=None): + if metrics is None: + metrics = {} + # this function changes the values in the out dictionary, + # make a copy instead of changing them in the callers version + out = {k: v.clone() for k, v in out.items()} + + natoms = batch.natoms + batch_size = natoms.numel() + + ### Retrieve free atoms + fixed = batch.fixed + mask = fixed == 0 + + s_idx = 0 + natoms_free = [] + for _natoms in natoms: + natoms_free.append(torch.sum(mask[s_idx : s_idx + _natoms]).item()) + s_idx += _natoms + natoms = torch.LongTensor(natoms_free).to(self.device) + + denoising_pos_forward = bool(batch.get("denoising_pos_forward", False)) + + targets = {} + for target_name in self.output_targets: + num_atoms_in_batch = batch.natoms.sum() + + if denoising_pos_forward and target_name == "forces": + if hasattr(batch, "noise_mask"): + force_target = batch.forces + denoising_pos_target = batch.noise_vec + noise_mask = batch.noise_mask + s2ef_index = torch.where(noise_mask == 0) + denoising_pos_index = torch.where(noise_mask == 1) + noise_mask_tensor = noise_mask.view(-1, 1) + targets["forces"] = ( + denoising_pos_target * noise_mask_tensor + + force_target * (~noise_mask_tensor) + ) + targets["noise_mask"] = noise_mask + else: + targets["forces"] = batch.noise_vec + + if "denoising_pos_target" in self.normalizers: + if hasattr(batch, "noise_mask"): + out["forces"][denoising_pos_index] = self.normalizers[ + "denoising_pos_target" + ].denorm(out["forces"][denoising_pos_index]) + else: + out["forces"] = self.normalizers["denoising_pos_target"].denorm( + out["forces"] + ) + + if hasattr(batch, "noise_mask"): + out["forces"][s2ef_index] = self.normalizers["forces"].denorm( + out["forces"][s2ef_index] + ) + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["eval_on_free_atoms"] + ): + if self.denoising_pos_params.all_atoms: + if hasattr(batch, "noise_mask"): + mask = mask | noise_mask + else: + mask = torch.ones_like( + mask, dtype=torch.bool, device=mask.device + ) + + targets["forces"] = targets["forces"][mask] + out["forces"] = out["forces"][mask] + num_atoms_in_batch = natoms.sum() + if "noise_mask" in targets: + targets["noise_mask"] = targets["noise_mask"][mask] + else: + target = batch[target_name] + + if ( + self.output_targets[target_name]["level"] == "atom" + and self.output_targets[target_name]["eval_on_free_atoms"] + ): + target = target[mask] + out[target_name] = out[target_name][mask] + num_atoms_in_batch = natoms.sum() + + ### reshape accordingly: num_atoms_in_batch, -1 or num_systems_in_batch, -1 + if self.output_targets[target_name]["level"] == "atom": + target = target.view(num_atoms_in_batch, -1) + else: + target = target.view(batch_size, -1) + + out[target_name] = self._denorm_preds( + target_name, out[target_name], batch + ) + targets[target_name] = target + + targets["natoms"] = natoms + out["natoms"] = natoms + + return denoising_pos_eval( + evaluator, + out, + targets, + denoising_targets=self.denoising_targets, + prev_metrics=metrics, + denoising_pos_forward=denoising_pos_forward, + ) + + @torch.no_grad() + def predict( + self, + data_loader, + per_image: bool = True, + results_file: str | None = None, + disable_tqdm: bool = False, + ): + if self.is_debug and per_image: + raise FileNotFoundError("Predictions require debug mode to be turned off.") + + ensure_fitted(self._unwrapped_model, warn=True) + + if distutils.is_master() and not disable_tqdm: + logging.info("Predicting on test.") + assert isinstance( + data_loader, + ( + torch.utils.data.dataloader.DataLoader, + torch_geometric.data.Batch, + ), + ) + rank = distutils.get_rank() + + if isinstance(data_loader, torch_geometric.data.Batch): + data_loader = [data_loader] + + self.model.eval() + if self.ema is not None: + self.ema.store() + self.ema.copy_to() + + predictions = defaultdict(list) + + for _, batch in tqdm( + enumerate(data_loader), + total=len(data_loader), + position=rank, + desc=f"device {rank}", + disable=disable_tqdm, + ): + with torch.cuda.amp.autocast(enabled=self.scaler is not None): + out = self._forward(batch) + + for key in out: + out[key] = out[key].float() + + for target_key in self.config["outputs"]: + pred = self._denorm_preds(target_key, out[target_key], batch) + + if per_image: + ### Save outputs in desired precision, default float16 + if ( + self.config["outputs"][target_key].get( + "prediction_dtype", "float16" + ) + == "float32" + or self.config["task"].get("prediction_dtype", "float16") + == "float32" + or self.config["task"].get("dataset", "lmdb") == "oc22_lmdb" + ): + dtype = torch.float32 + else: + dtype = torch.float16 + + pred = pred.detach().cpu().to(dtype) + + ### Split predictions into per-image predictions + if self.config["outputs"][target_key]["level"] == "atom": + batch_natoms = batch.natoms + batch_fixed = batch.fixed + per_image_pred = torch.split(pred, batch_natoms.tolist()) + + ### Save out only free atom, EvalAI does not need fixed atoms + _per_image_fixed = torch.split( + batch_fixed, batch_natoms.tolist() + ) + _per_image_free_preds = [ + _pred[(fixed == 0).tolist()].numpy() + for _pred, fixed in zip(per_image_pred, _per_image_fixed) + ] + _chunk_idx = np.array( + [free_pred.shape[0] for free_pred in _per_image_free_preds] + ) + per_image_pred = _per_image_free_preds + ### Assumes system level properties are of the same dimension + else: + per_image_pred = pred.numpy() + _chunk_idx = None + + predictions[f"{target_key}"].extend(per_image_pred) + ### Backwards compatibility, retain 'chunk_idx' for forces. + if _chunk_idx is not None: + if target_key == "forces": + predictions["chunk_idx"].extend(_chunk_idx) + else: + predictions[f"{target_key}_chunk_idx"].extend(_chunk_idx) + else: + predictions[f"{target_key}"] = pred.detach() + + if not per_image: + return predictions + + ### Get unique system identifiers + sids = ( + batch.sid.tolist() if isinstance(batch.sid, torch.Tensor) else batch.sid + ) + ## Support naming structure for OC20 S2EF + if "fid" in batch: + fids = ( + batch.fid.tolist() + if isinstance(batch.fid, torch.Tensor) + else batch.fid + ) + systemids = [f"{sid}_{fid}" for sid, fid in zip(sids, fids)] + else: + systemids = [f"{sid}" for sid in sids] + + predictions["ids"].extend(systemids) + + self.save_results(predictions, results_file) + + if self.ema: + self.ema.restore() + + return predictions diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 9a13faed6..a8976773c 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -47,7 +47,7 @@ class OCPTrainer(BaseTrainer): Args: task (dict): Task configuration. model (dict): Model configuration. - outputs (dict): Output property configuration. + outputs (dict): Dictionary of model output configuration. dataset (dict): Dataset configuration. The dataset needs to be a SinglePointLMDB dataset. optimizer (dict): Optimizer configuration. loss_functions (dict): Loss function configuration. @@ -55,6 +55,8 @@ class OCPTrainer(BaseTrainer): identifier (str): Experiment identifier that is appended to log directory. run_dir (str, optional): Path to the run directory where logs are to be saved. (default: :obj:`None`) + timestamp_id (str, optional): timestamp identifier. + run_dir (str, optional): Run directory used to save checkpoints and results. is_debug (bool, optional): Run in debug mode. (default: :obj:`False`) print_every (int, optional): Frequency of printing logs. @@ -63,10 +65,17 @@ class OCPTrainer(BaseTrainer): (default: :obj:`None`) logger (str, optional): Type of logger to be used. (default: :obj:`wandb`) + local_rank (int, optional): Local rank of the process, only applicable for distributed training. + (default: :obj:`0`) amp (bool, optional): Run using automatic mixed precision. (default: :obj:`False`) + cpu (bool): If True will run on CPU. Default is False, will attempt to use cuda. + name (str): Trainer name. slurm (dict): Slurm configuration. Currently just for keeping track. (default: :obj:`{}`) + gp_gpus (int, optional): Number of graph parallel GPUs. + inference_only (bool): If true trainer will be loaded for inference only. + (ie datasets, optimizer, schedular, etc, will not be instantiated) """ def __init__( @@ -91,7 +100,7 @@ def __init__( amp: bool = False, cpu: bool = False, name: str = "ocp", - slurm=None, + slurm: dict | None = None, gp_gpus: int | None = None, inference_only: bool = False, ): @@ -260,6 +269,7 @@ def _forward(self, batch): ), f"we need to know which property to match the target to, please specify the property field in the task config, current config: {self.output_targets[target_key]}" prop = self.output_targets[target_key]["property"] pred = out[target_key][prop] + # TODO clean up this logic to reconstruct a tensor from its predicted decomposition elif "decomposition" in self.output_targets[target_key]: _max_rank = 0 From 3f695efd4877db6cdb4adde9d5c8ca8008ce2a75 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Tue, 3 Dec 2024 10:52:44 -0800 Subject: [PATCH 5/7] OptimizableBatch and stress relaxations (#718) * remove r_edges, radius, max_neigh and add deprecation warning * edit typing and dont use dicts as default * use super() and remove overkill deprecation warning * set implemented_properties from config * make determine step a method * allow calculator to operate on batches * only update if old config is used * reshape properties * no test classes in ase calculator * yaml load fix * use mappingproxy * expressive import * remove duplicated code * optimizable batch class for ase compatible batch relaxations * fix optimizable batch * optimizable goodies * apply force constraints * use optimizable batch instead and remove torchcalc * update ml relaxations to use optimizable batch correctly * force_consistent check for ASE compat * force_consistent check for ASE compat * check force_consistent * init docs in lbfgs * unitcellfilter for batch relaxations * ruff * UnitCellOptimizable as child class instead of filter * allow running unit cell relaxations * ruff * no grad in run_relaxations * make batched_dot and determine_step methods * imports * rename to optimizableunitcellbatch * allow passing energy and forces explicitly to batch to atoms * check convergence in optimizable and allow passing general results to atoms_from_batch * relaxation test * unit tests * move update mask to optimizable * use energy instead of y * all setting/getting positions and convergence in optimizable * more (unfinished) tests * backwards compatible test * minor fixes * code cleanup * add/fix tests * fix lbfgs * assert using norm * add eps to masked batches if using ASE optimizers * match iterations from previous implementation * use float64 for forces * float32 * use energy_relaxed instead of y_relaxed * energy_relaxed and more explicit error msg * default to batch_size 1 if not set in config * keep float64 training * rename y_relaxed -> energy_relaxed * rm expcell batch * convenience commit from no_experimental_resolve * use numatoms tensor for cell factor * remove positions tests (wrapping atoms gives different results) * allow wrapping positions in batch to atoms * fix test * wrap_positions in batch_to_atoms * take a2g properties from model * test lbfgs traj writes * remove comments * use model generate graph * fix cell_factor * fix using model in ddp * fix r_edges in OCPcalculator * write initial and final structure if save_full is false * check unique atoms saved in trajectory * tighter tol * update ASE release comment * remove cumulative mask option * remove left over cumulative_mask * fix batching when sids as str * do not try to fetch energy and forces if no explicit results * accept Path objects * clean up setting defaults * expose ml_relax in relaxation * force set r_pbc True * make relax_opt optional * no ema on inference only * define ema none to avoid issues * lower force threshold to make sure test does not converge * clean up exception msg * allow strings in batch * remove device argument from lbfgs * minor cleanup * fix optimizable import * do not pass device in ml_relax * simplify enforce max neighbors * fix tests (still not testing stress) * pin sphinx autoapi * typo in version --------- Co-authored-by: zulissimeta <122578103+zulissimeta@users.noreply.github.com> Co-authored-by: Zack Ulissi --- packages/fairchem-core/pyproject.toml | 2 +- .../core/common/relaxation/__init__.py | 13 + .../core/common/relaxation/ase_utils.py | 108 +++- .../core/common/relaxation/ml_relaxation.py | 117 ++-- .../core/common/relaxation/optimizable.py | 547 ++++++++++++++++++ .../common/relaxation/optimizers/__init__.py | 12 + .../relaxation/optimizers/lbfgs_torch.py | 238 ++++---- src/fairchem/core/datasets/ase_datasets.py | 16 +- src/fairchem/core/models/base.py | 12 +- .../core/preprocessing/atoms_to_graphs.py | 2 +- src/fairchem/core/trainers/base_trainer.py | 7 +- src/fairchem/core/trainers/ocp_trainer.py | 13 +- tests/core/common/conftest.py | 33 ++ tests/core/common/test_ase_calculator.py | 3 + tests/core/common/test_lbfgs_torch.py | 66 +++ tests/core/common/test_optimizable.py | 110 ++++ tests/core/datasets/test_ase_datasets.py | 10 +- 17 files changed, 1068 insertions(+), 241 deletions(-) create mode 100644 src/fairchem/core/common/relaxation/optimizable.py create mode 100644 tests/core/common/conftest.py create mode 100644 tests/core/common/test_lbfgs_torch.py create mode 100644 tests/core/common/test_optimizable.py diff --git a/packages/fairchem-core/pyproject.toml b/packages/fairchem-core/pyproject.toml index ee92db45d..dfd5c671a 100644 --- a/packages/fairchem-core/pyproject.toml +++ b/packages/fairchem-core/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ [project.optional-dependencies] # add optional dependencies to be installed as pip install fairchem.core[dev] dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"] -docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi", "umap-learn", "vdict"] +docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "umap-learn", "vdict"] adsorbml = ["dscribe","x3dase","scikit-image"] [project.scripts] diff --git a/src/fairchem/core/common/relaxation/__init__.py b/src/fairchem/core/common/relaxation/__init__.py index e69de29bb..1700e0040 100644 --- a/src/fairchem/core/common/relaxation/__init__.py +++ b/src/fairchem/core/common/relaxation/__init__.py @@ -0,0 +1,13 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from .ml_relaxation import ml_relax +from .optimizable import OptimizableBatch, OptimizableUnitCellBatch + +__all__ = ["ml_relax", "OptimizableBatch", "OptimizableUnitCellBatch"] diff --git a/src/fairchem/core/common/relaxation/ase_utils.py b/src/fairchem/core/common/relaxation/ase_utils.py index 2dacce2cb..5a9302d88 100644 --- a/src/fairchem/core/common/relaxation/ase_utils.py +++ b/src/fairchem/core/common/relaxation/ase_utils.py @@ -14,13 +14,15 @@ import copy import logging -from typing import ClassVar +from types import MappingProxyType +from typing import TYPE_CHECKING import torch from ase import Atoms from ase.calculators.calculator import Calculator -from ase.calculators.singlepoint import SinglePointCalculator as sp +from ase.calculators.singlepoint import SinglePointCalculator from ase.constraints import FixAtoms +from ase.geometry import wrap_positions from fairchem.core.common.registry import registry from fairchem.core.common.utils import ( @@ -33,51 +35,93 @@ from fairchem.core.models.model_registry import model_name_to_local_file from fairchem.core.preprocessing import AtomsToGraphs +if TYPE_CHECKING: + from pathlib import Path -def batch_to_atoms(batch): + from torch_geometric.data import Batch + + +# system level model predictions have different shapes than expected by ASE +ASE_PROP_RESHAPE = MappingProxyType( + {"stress": (-1, 3, 3), "dielectric_tensor": (-1, 3, 3)} +) + + +def batch_to_atoms( + batch: Batch, + results: dict[str, torch.Tensor] | None = None, + wrap_pos: bool = True, + eps: float = 1e-7, +) -> list[Atoms]: + """Convert a data batch to ase Atoms + + Args: + batch: data batch + results: dictionary with predicted result tensors that will be added to a SinglePointCalculator. If no results + are given no calculator will be added to the atoms objects. + wrap_pos: wrap positions back into the cell. + eps: Small number to prevent slightly negative coordinates from being wrapped. + + Returns: + list of Atoms + """ n_systems = batch.natoms.shape[0] natoms = batch.natoms.tolist() numbers = torch.split(batch.atomic_numbers, natoms) fixed = torch.split(batch.fixed.to(torch.bool), natoms) - forces = torch.split(batch.force, natoms) + if results is not None: + results = { + key: val.view(ASE_PROP_RESHAPE.get(key, -1)).tolist() + if len(val) == len(batch) + else [v.cpu().detach().numpy() for v in torch.split(val, natoms)] + for key, val in results.items() + } + positions = torch.split(batch.pos, natoms) tags = torch.split(batch.tags, natoms) cells = batch.cell - energies = batch.energy.view(-1).tolist() atoms_objects = [] for idx in range(n_systems): + pos = positions[idx].cpu().detach().numpy() + cell = cells[idx].cpu().detach().numpy() + + # TODO take pbc from data + if wrap_pos: + pos = wrap_positions(pos, cell, pbc=[True, True, True], eps=eps) + atoms = Atoms( numbers=numbers[idx].tolist(), - positions=positions[idx].cpu().detach().numpy(), + cell=cell, + positions=pos, tags=tags[idx].tolist(), - cell=cells[idx].cpu().detach().numpy(), constraint=FixAtoms(mask=fixed[idx].tolist()), pbc=[True, True, True], ) - calc = sp( - atoms=atoms, - energy=energies[idx], - forces=forces[idx].cpu().detach().numpy(), - ) - atoms.set_calculator(calc) + + if results is not None: + calc = SinglePointCalculator( + atoms=atoms, **{key: val[idx] for key, val in results.items()} + ) + atoms.set_calculator(calc) + atoms_objects.append(atoms) return atoms_objects class OCPCalculator(Calculator): - implemented_properties: ClassVar[list[str]] = ["energy", "forces"] + """ASE based calculator using an OCP model""" + + _reshaped_props = ASE_PROP_RESHAPE def __init__( self, config_yml: str | None = None, - checkpoint_path: str | None = None, + checkpoint_path: str | Path | None = None, model_name: str | None = None, local_cache: str | None = None, trainer: str | None = None, - cutoff: int = 6, - max_neighbors: int = 50, cpu: bool = True, seed: int | None = None, ) -> None: @@ -96,16 +140,12 @@ def __init__( Directory to save pretrained model checkpoints. trainer (str): OCP trainer to be used. "forces" for S2EF, "energy" for IS2RE. - cutoff (int): - Cutoff radius to be used for data preprocessing. - max_neighbors (int): - Maximum amount of neighbors to store for a given atom. cpu (bool): Whether to load and run the model on CPU. Set `False` for GPU. """ setup_imports() setup_logging() - Calculator.__init__(self) + super().__init__() if model_name is not None: if checkpoint_path is not None: @@ -165,9 +205,8 @@ def __init__( ### backwards compatability with OCP v<2.0 config = update_config(config) - # Save config so obj can be transported over network (pkl) self.config = copy.deepcopy(config) - self.config["checkpoint"] = checkpoint_path + self.config["checkpoint"] = str(checkpoint_path) del config["dataset"]["src"] self.trainer = registry.get_trainer_class(config["trainer"])( @@ -199,14 +238,13 @@ def __init__( self.trainer.set_seed(seed) self.a2g = AtomsToGraphs( - max_neigh=max_neighbors, - radius=cutoff, r_energy=False, r_forces=False, r_distances=False, - r_edges=False, r_pbc=True, + r_edges=not self.trainer.model.otf_graph, # otf graph should not be a property of the model... ) + self.implemented_properties = list(self.config["outputs"].keys()) def load_checkpoint( self, checkpoint_path: str, checkpoint: dict | None = None @@ -217,6 +255,8 @@ def load_checkpoint( Args: checkpoint_path: string Path to trained model + checkpoint: dict + A pretrained checkpoint dict """ try: self.trainer.load_checkpoint( @@ -225,14 +265,20 @@ def load_checkpoint( except NotImplementedError: logging.warning("Unable to load checkpoint!") - def calculate(self, atoms: Atoms, properties, system_changes) -> None: - Calculator.calculate(self, atoms, properties, system_changes) - data_object = self.a2g.convert(atoms) - batch = data_list_collater([data_object], otf_graph=True) + def calculate(self, atoms: Atoms | Batch, properties, system_changes) -> None: + """Calculate implemented properties for a single Atoms object or a Batch of them.""" + super().calculate(atoms, properties, system_changes) + if isinstance(atoms, Atoms): + data_object = self.a2g.convert(atoms) + batch = data_list_collater([data_object], otf_graph=True) + else: + batch = atoms predictions = self.trainer.predict(batch, per_image=False, disable_tqdm=True) for key in predictions: _pred = predictions[key] _pred = _pred.item() if _pred.numel() == 1 else _pred.cpu().numpy() + if key in OCPCalculator._reshaped_props: + _pred = _pred.reshape(OCPCalculator._reshaped_props.get(key)).squeeze() self.results[key] = _pred diff --git a/src/fairchem/core/common/relaxation/ml_relaxation.py b/src/fairchem/core/common/relaxation/ml_relaxation.py index 406b6b1cc..bf5eb3cac 100644 --- a/src/fairchem/core/common/relaxation/ml_relaxation.py +++ b/src/fairchem/core/common/relaxation/ml_relaxation.py @@ -10,6 +10,7 @@ import logging from collections import deque from pathlib import Path +from typing import TYPE_CHECKING import torch from torch_geometric.data import Batch @@ -17,70 +18,94 @@ from fairchem.core.common.typing import assert_is_instance from fairchem.core.datasets.lmdb_dataset import data_list_collater -from .optimizers.lbfgs_torch import LBFGS, TorchCalc +from .optimizable import OptimizableBatch, OptimizableUnitCellBatch +from .optimizers.lbfgs_torch import LBFGS + +if TYPE_CHECKING: + from fairchem.core.trainers import BaseTrainer def ml_relax( - batch, - model, + batch: Batch, + model: BaseTrainer, steps: int, fmax: float, - relax_opt, - save_full_traj, - device: str = "cuda:0", - transform=None, - early_stop_batch: bool = False, + relax_opt: dict[str] | None = None, + relax_cell: bool = False, + relax_volume: bool = False, + save_full_traj: bool = True, + transform: torch.nn.Module | None = None, + mask_converged: bool = True, ): - """ - Runs ML-based relaxations. + """Runs ML-based relaxations. + Args: - batch: object - model: object - steps: int - Max number of steps in the structure relaxation. - fmax: float - Structure relaxation terminates when the max force - of the system is no bigger than fmax. - relax_opt: str - Optimizer and corresponding parameters to be used for structure relaxations. - save_full_traj: bool - Whether to save out the full ASE trajectory. If False, only save out initial and final frames. + batch: a data batch object. + model: a trainer object with model. + steps: Max number of steps in the structure relaxation. + fmax: Structure relaxation terminates when the max force of the system is no bigger than fmax. + relax_opt: Optimizer parameters to be used for structure relaxations. + relax_cell: if true will use stress predictions to relax crystallographic cell. + The model given must predict stress + relax_volume: if true will relax the cell isotropically. the given model must predict stress. + save_full_traj: Whether to save out the full ASE trajectory. If False, only save out initial and final frames. + mask_converged: whether to mask batches where all atoms are below convergence threshold + cumulative_mask: if true, once system is masked then it remains masked even if new predictions give forces + above threshold, ie. once masked always masked. Note if this is used make sure to check convergence with + the same fmax always """ + relax_opt = relax_opt or {} + # if not pbc is set, ignore it when comparing batches + if not hasattr(batch, "pbc"): + OptimizableBatch.ignored_changes = {"pbc"} + batches = deque([batch]) relaxed_batches = [] while batches: batch = batches.popleft() oom = False ids = batch.sid - calc = TorchCalc(model, transform) + + # clone the batch otherwise you can not run batch.to_data_list + # see https://github.com/pyg-team/pytorch_geometric/issues/8439#issuecomment-1826747915 + if relax_cell or relax_volume: + optimizable = OptimizableUnitCellBatch( + batch.clone(), + trainer=model, + transform=transform, + mask_converged=mask_converged, + hydrostatic_strain=relax_volume, + ) + else: + optimizable = OptimizableBatch( + batch.clone(), + trainer=model, + transform=transform, + mask_converged=mask_converged, + ) # Run ML-based relaxation - traj_dir = relax_opt.get("traj_dir", None) + traj_dir = relax_opt.get("traj_dir") + relax_opt.update({"traj_dir": Path(traj_dir) if traj_dir is not None else None}) + optimizer = LBFGS( - batch, - calc, - maxstep=relax_opt.get("maxstep", 0.2), - memory=relax_opt["memory"], - damping=relax_opt.get("damping", 1.2), - alpha=relax_opt.get("alpha", 80.0), - device=device, + optimizable_batch=optimizable, save_full_traj=save_full_traj, - traj_dir=Path(traj_dir) if traj_dir is not None else None, traj_names=ids, - early_stop_batch=early_stop_batch, + **relax_opt, ) e: RuntimeError | None = None try: - relaxed_batch = optimizer.run(fmax=fmax, steps=steps) - relaxed_batches.append(relaxed_batch) + optimizer.run(fmax=fmax, steps=steps) + relaxed_batches.append(optimizable.batch) except RuntimeError as err: e = err oom = True torch.cuda.empty_cache() if oom: - # move OOM recovery code outside of except clause to allow tensors to be freed. + # move OOM recovery code outside off except clause to allow tensors to be freed. data_list = batch.to_data_list() if len(data_list) == 1: raise assert_is_instance(e, RuntimeError) @@ -88,7 +113,23 @@ def ml_relax( f"Failed to relax batch with size: {len(data_list)}, splitting into two..." ) mid = len(data_list) // 2 - batches.appendleft(data_list_collater(data_list[:mid])) - batches.appendleft(data_list_collater(data_list[mid:])) + batches.appendleft( + data_list_collater(data_list[:mid], otf_graph=optimizable.otf_graph) + ) + batches.appendleft( + data_list_collater(data_list[mid:], otf_graph=optimizable.otf_graph) + ) + + # reset for good measure + OptimizableBatch.ignored_changes = {} + + relaxed_batch = Batch.from_data_list(relaxed_batches) + + # Batch.from_data_list is not intended to be used with a list of batches, so when sid is a list of str + # it will be incorrectly collated as a list of lists for each batch. + # but we can not use to_data_list in the relaxed batches (since they have been changed, see linked comment above). + # So instead just manually fix it for now. Remove this once pyg dependency is removed + if isinstance(relaxed_batch.sid, list): + relaxed_batch.sid = [sid for sid_list in relaxed_batch.sid for sid in sid_list] - return Batch.from_data_list(relaxed_batches) + return relaxed_batch diff --git a/src/fairchem/core/common/relaxation/optimizable.py b/src/fairchem/core/common/relaxation/optimizable.py new file mode 100644 index 000000000..c40f46126 --- /dev/null +++ b/src/fairchem/core/common/relaxation/optimizable.py @@ -0,0 +1,547 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Code based on ase.optimize +""" + +from __future__ import annotations + +from functools import cached_property +from types import SimpleNamespace +from typing import TYPE_CHECKING, ClassVar + +import numpy as np +import torch +from ase.calculators.calculator import PropertyNotImplementedError +from ase.stress import voigt_6_to_full_3x3_stress +from torch_scatter import scatter + +from fairchem.core.common.relaxation.ase_utils import batch_to_atoms + +# this can be removed after pinning ASE dependency >= 3.23 +try: + from ase.optimize.optimize import Optimizable +except ImportError: + + class Optimizable: + pass + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from ase import Atoms + from numpy.typing import NDArray + from torch_geometric.data import Batch + + from fairchem.core.trainers import BaseTrainer + + +ALL_CHANGES: set[str] = { + "pos", + "atomic_numbers", + "cell", + "pbc", +} + + +def compare_batches( + batch1: Batch | None, + batch2: Batch, + tol: float = 1e-6, + excluded_properties: set[str] | None = None, +) -> list[str]: + """Compare properties between two batches + + Args: + batch1: atoms batch + batch2: atoms batch + tol: tolerance used to compare equility of floating point properties + excluded_properties: list of properties to exclude from comparison + + Returns: + list of system changes, property names that are differente between batch1 and batch2 + """ + system_changes = [] + + if batch1 is None: + system_changes = ALL_CHANGES + else: + properties_to_check = set(ALL_CHANGES) + if excluded_properties: + properties_to_check -= set(excluded_properties) + + # Check properties that aren't + for prop in ALL_CHANGES: + if prop in properties_to_check: + properties_to_check.remove(prop) + if not torch.allclose( + getattr(batch1, prop), getattr(batch2, prop), atol=tol + ): + system_changes.append(prop) + + return system_changes + + +class OptimizableBatch(Optimizable): + """A Batch version of ase Optimizable Atoms + + This class can be used with ML relaxations in fairchem.core.relaxations.ml_relaxation + or in ase relaxations classes, i.e. ase.optimize.lbfgs + """ + + ignored_changes: ClassVar[set[str]] = set() + + def __init__( + self, + batch: Batch, + trainer: BaseTrainer, + transform: torch.nn.Module | None = None, + mask_converged: bool = True, + numpy: bool = False, + masked_eps: float = 1e-8, + ): + """Initialize Optimizable Batch + + Args: + batch: A batch of atoms graph data + model: An instance of a BaseTrainer derived class + transform: graph transform + mask_converged: if true will mask systems in batch that are already converged + numpy: whether to cast results to numpy arrays + masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero + from zero differences in masked positions at future steps, we add a small number to prevent this. + """ + self.batch = batch.to(trainer.device) + self.trainer = trainer + self.transform = transform + self.numpy = numpy + self.mask_converged = mask_converged + self._cached_batch = None + self._update_mask = None + self.torch_results = {} + self.results = {} + self._eps = masked_eps + + self.otf_graph = True # trainer._unwrapped_model.otf_graph + if not self.otf_graph and "edge_index" not in self.batch: + self.update_graph() + + @property + def device(self): + return self.trainer.device + + @property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch.""" + return self.batch.batch + + @property + def converged_mask(self): + if self._update_mask is not None: + return torch.logical_not(self._update_mask) + return None + + @property + def update_mask(self): + if self._update_mask is None: + return torch.ones(len(self.batch), dtype=bool) + return self._update_mask + + def check_state(self, batch: Batch, tol: float = 1e-12) -> bool: + """Check for any system changes since last calculation.""" + return compare_batches( + self._cached_batch, + batch, + tol=tol, + excluded_properties=set(self.ignored_changes), + ) + + def _predict(self) -> None: + """Run prediction if batch has any changes.""" + system_changes = self.check_state(self.batch) + if len(system_changes) > 0: + self.torch_results = self.trainer.predict( + self.batch, per_image=False, disable_tqdm=True + ) + # save only subset of props in simple namespace instead of cloning the whole batch to save memory + changes = ALL_CHANGES - set(self.ignored_changes) + self._cached_batch = SimpleNamespace( + **{prop: self.batch[prop].clone() for prop in changes} + ) + + def get_property(self, name, no_numpy: bool = False) -> torch.Tensor | NDArray: + """Get a predicted property by name.""" + self._predict() + if self.numpy: + self.results = { + key: pred.item() if pred.numel() == 1 else pred.cpu().numpy() + for key, pred in self.torch_results.items() + } + else: + self.results = self.torch_results + + if name not in self.results: + raise PropertyNotImplementedError(f"{name} not present in this calculation") + + return self.results[name] if no_numpy is False else self.torch_results[name] + + def get_positions(self) -> torch.Tensor | NDArray: + """Get the batch positions""" + pos = self.batch.pos.clone() + if self.numpy: + if self.mask_converged: + pos[~self.update_mask[self.batch.batch]] = self._eps + pos = pos.cpu().numpy() + + return pos + + def set_positions(self, positions: torch.Tensor | NDArray) -> None: + """Set the atom positions in the batch.""" + if isinstance(positions, np.ndarray): + positions = torch.tensor(positions) + + positions = positions.to(dtype=torch.float32, device=self.device) + if self.mask_converged and self._update_mask is not None: + mask = self.update_mask[self.batch.batch] + self.batch.pos[mask] = positions[mask] + else: + self.batch.pos = positions + + if not self.otf_graph: + self.update_graph() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get predicted batch forces.""" + forces = self.get_property("forces", no_numpy=no_numpy) + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + if isinstance(forces, np.ndarray): + fixed_idx = fixed_idx.tolist() + forces[fixed_idx] = 0.0 + return forces + + def get_potential_energy(self, **kwargs) -> torch.Tensor | NDArray: + """Get predicted energy as the sum of all batch energies.""" + # ASE 3.22.1 expects a check for force_consistent calculations + if kwargs.get("force_consistent", False) is True: + raise PropertyNotImplementedError( + "force_consistent calculations are not implemented" + ) + if ( + len(self.batch) == 1 + ): # unfortunately batch size 1 returns a float, not a tensor + return self.get_property("energy") + return self.get_property("energy").sum() + + def get_potential_energies(self) -> torch.Tensor | NDArray: + """Get the predicted energy for each system in batch.""" + return self.get_property("energy") + + def get_cells(self) -> torch.Tensor: + """Get batch crystallographic cells.""" + return self.batch.cell + + def set_cells(self, cells: torch.Tensor | NDArray) -> None: + """Set batch cells.""" + assert self.batch.cell.shape == cells.shape, "Cell shape mismatch" + if isinstance(cells, np.ndarray): + cells = torch.tensor(cells, dtype=torch.float32, device=self.device) + cells = cells.to(dtype=torch.float32, device=self.device) + self.batch.cell[self.update_mask] = cells[self.update_mask] + + def get_volumes(self) -> torch.Tensor: + """Get a tensor of volumes for each cell in batch""" + cells = self.get_cells() + return torch.linalg.det(cells) + + def iterimages(self) -> Batch: + # XXX document purpose of iterimages - this is just needed to work with ASE optimizers + yield self.batch + + def get_max_forces( + self, forces: torch.Tensor | None = None, apply_constraint: bool = False + ) -> torch.Tensor: + """Get the maximum forces per structure in batch""" + if forces is None: + forces = self.get_forces(apply_constraint=apply_constraint, no_numpy=True) + return scatter((forces**2).sum(axis=1).sqrt(), self.batch_indices, reduce="max") + + def converged( + self, + forces: torch.Tensor | NDArray | None, + fmax: float, + max_forces: torch.Tensor | None = None, + ) -> bool: + """Check if norm of all predicted forces are below fmax""" + if forces is not None: + if isinstance(forces, np.ndarray): + forces = torch.tensor(forces, device=self.device, dtype=torch.float32) + max_forces = self.get_max_forces(forces) + elif max_forces is None: + max_forces = self.get_max_forces() + + update_mask = max_forces.ge(fmax) + # update cached mask + if self.mask_converged: + if self._update_mask is None: + self._update_mask = update_mask + else: + # some models can have random noise in their predictions, so the mask is updated by + # keeping all previously converged structures masked even if new force predictions + # push it slightly above threshold + self._update_mask = torch.logical_and(self._update_mask, update_mask) + update_mask = self._update_mask + + return not torch.any(update_mask).item() + + def get_atoms_list(self) -> list[Atoms]: + """Get ase Atoms objects corresponding to the batch""" + self._predict() # in case no predictions have been run + return batch_to_atoms(self.batch, results=self.torch_results) + + def update_graph(self): + """Update the graph if model does not use otf_graph.""" + graph = self.trainer._unwrapped_model.generate_graph(self.batch) + self.batch.edge_index = graph.edge_index + self.batch.cell_offsets = graph.cell_offsets + self.batch.neighbors = graph.neighbors + if self.transform is not None: + self.batch = self.transform(self.batch) + + def __len__(self) -> int: + # TODO: this might be changed in ASE to be 3 * len(self.atoms) + return len(self.batch.pos) + + +class OptimizableUnitCellBatch(OptimizableBatch): + """Modify the supercell and the atom positions in relaxations. + + Based on ase UnitCellFilter to work on data batches + """ + + def __init__( + self, + batch: Batch, + trainer: BaseTrainer, + transform: torch.nn.Module | None = None, + numpy: bool = False, + mask_converged: bool = True, + mask: Sequence[bool] | None = None, + cell_factor: float | torch.Tensor | None = None, + hydrostatic_strain: bool = False, + constant_volume: bool = False, + scalar_pressure: float = 0.0, + masked_eps: float = 1e-8, + ): + """Create a filter that returns the forces and unit cell stresses together, for simultaneous optimization. + + For full details see: + E. B. Tadmor, G. S. Smith, N. Bernstein, and E. Kaxiras, + Phys. Rev. B 59, 235 (1999) + + Args: + batch: A batch of atoms graph data + model: An instance of a BaseTrainer derived class + transform: graph transform + numpy: whether to cast results to numpy arrays + mask_converged: if true will mask systems in batch that are already converged + mask: a boolean mask specifying which strain components are allowed to relax + cell_factor: + Factor by which deformation gradient is multiplied to put + it on the same scale as the positions when assembling + the combined position/cell vector. The stress contribution to + the forces is scaled down by the same factor. This can be thought + of as a very simple preconditioner. Default is number of atoms + which gives approximately the correct scaling. + hydrostatic_strain: + Constrain the cell by only allowing hydrostatic deformation. + The virial tensor is replaced by np.diag([np.trace(virial)]*3). + constant_volume: + Project out the diagonal elements of the virial tensor to allow + relaxations at constant volume, e.g. for mapping out an + energy-volume curve. Note: this only approximately conserves + the volume and breaks energy/force consistency so can only be + used with optimizers that do require a line minimisation + (e.g. FIRE). + scalar_pressure: + Applied pressure to use for enthalpy pV term. As above, this + breaks energy/force consistency. + masked_eps: masking systems that are converged when using ASE optimizers results in divisions by zero + from zero differences in masked positions at future steps, we add a small number to prevent this. + """ + super().__init__( + batch=batch, + trainer=trainer, + transform=transform, + numpy=numpy, + mask_converged=mask_converged, + masked_eps=masked_eps, + ) + + self.orig_cells = self.get_cells().clone() + self.stress = None + + if mask is None: + mask = torch.eye(3, device=self.device) + + # TODO make sure mask is on GPU + if mask.shape == (6,): + self.mask = torch.tensor( + voigt_6_to_full_3x3_stress(mask.detach().cpu()), + device=self.device, + ) + elif mask.shape == (3, 3): + self.mask = mask + else: + raise ValueError("shape of mask should be (3,3) or (6,)") + + if isinstance(cell_factor, float): + cell_factor = cell_factor * torch.ones( + (3 * len(batch), 1), requires_grad=False + ) + if cell_factor is None: + cell_factor = self.batch.natoms.repeat_interleave(3).unsqueeze(dim=1) + + self.hydrostatic_strain = hydrostatic_strain + self.constant_volume = constant_volume + self.pressure = scalar_pressure * torch.eye(3, device=self.device) + self.cell_factor = cell_factor + self.stress = None + self._batch_trace = torch.vmap(torch.trace) + self._batch_diag = torch.vmap(lambda x: x * torch.eye(3, device=x.device)) + + @cached_property + def batch_indices(self): + """Get the batch indices specifying which position/force corresponds to which batch. + + We augment this to specify the batch indices for augmented positions and forces. + """ + augmented_batch = torch.repeat_interleave( + torch.arange( + len(self.batch), dtype=self.batch.batch.dtype, device=self.device + ), + 3, + ) + return torch.cat([self.batch.batch, augmented_batch]) + + def deform_grad(self): + """Get the cell deformation matrix""" + return torch.transpose( + torch.linalg.solve(self.orig_cells, self.get_cells()), 1, 2 + ) + + def get_positions(self): + """Get positions and cell deformation gradient.""" + cur_deform_grad = self.deform_grad() + natoms = self.batch.num_nodes + pos = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + dtype=self.batch.pos.dtype, + device=self.device, + ) + + # Augmented positions are the self.atoms.positions but without the applied deformation gradient + pos[:natoms] = torch.linalg.solve( + cur_deform_grad[self.batch.batch, :, :], + self.batch.pos.view(-1, 3, 1), + ).view(-1, 3) + # cell DOFs are the deformation gradient times a scaling factor + pos[natoms:] = self.cell_factor * cur_deform_grad.view(-1, 3) + return pos.cpu().numpy() if self.numpy else pos + + def set_positions(self, positions: torch.Tensor | NDArray): + """Set positions and cell. + + positions has shape (natoms + ncells * 3, 3). + the first natoms rows are the positions of the atoms, the last nsystems * three rows are the deformation tensor + for each cell. + """ + if isinstance(positions, np.ndarray): + positions = torch.tensor(positions) + + positions = positions.to(dtype=torch.float32, device=self.device) + natoms = self.batch.num_nodes + new_atom_positions = positions[:natoms] + new_deform_grad = (positions[natoms:] / self.cell_factor).view(-1, 3, 3) + + # TODO check that in fact symmetry is preserved setting cells and positions + # Set the new cell from the original cell and the new deformation gradient. Both current and final structures + # should preserve symmetry. + new_cells = torch.bmm(self.orig_cells, torch.transpose(new_deform_grad, 1, 2)) + self.set_cells(new_cells) + + # Set the positions from the ones passed in (which are without the deformation gradient applied) and the new + # deformation gradient. This should also preserve symmetry + new_atom_positions = torch.bmm( + new_atom_positions.view(-1, 1, 3), + torch.transpose( + new_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), 1, 2 + ), + ) + super().set_positions(new_atom_positions.view(-1, 3)) + + def get_potential_energy(self, **kwargs): + """ + returns potential energy including enthalpy PV term. + """ + atoms_energy = super().get_potential_energy(**kwargs) + return atoms_energy + self.pressure[0, 0] * self.get_volumes().sum() + + def get_forces( + self, apply_constraint: bool = False, no_numpy: bool = False + ) -> torch.Tensor | NDArray: + """Get forces and unit cell stress.""" + stress = self.get_property("stress", no_numpy=True).view(-1, 3, 3) + atom_forces = self.get_property("forces", no_numpy=True) + + if apply_constraint: + fixed_idx = torch.where(self.batch.fixed == 1)[0] + atom_forces[fixed_idx] = 0.0 + + volumes = self.get_volumes().view(-1, 1, 1) + virial = -volumes * stress + self.pressure.view(-1, 3, 3) + cur_deform_grad = self.deform_grad() + atom_forces = torch.bmm( + atom_forces.view(-1, 1, 3), + cur_deform_grad[self.batch.batch, :, :].view(-1, 3, 3), + ) + virial = torch.linalg.solve( + cur_deform_grad, torch.transpose(virial, dim0=1, dim1=2) + ) + virial = torch.transpose(virial, dim0=1, dim1=2) + + # TODO this does not work yet! maybe _batch_trace gives an issue + if self.hydrostatic_strain: + virial = self._batch_diag(self._batch_trace(virial) / 3.0) + + # Zero out components corresponding to fixed lattice elements + if (self.mask != 1.0).any(): + virial *= self.mask.view(-1, 3, 3) + + if self.constant_volume: + virial[:, range(3), range(3)] -= self._batch_trace(virial).view(3, -1) / 3.0 + + natoms = self.batch.num_nodes + augmented_forces = torch.zeros( + (natoms + 3 * len(self.get_cells()), 3), + device=self.device, + dtype=atom_forces.dtype, + ) + augmented_forces[:natoms] = atom_forces.view(-1, 3) + augmented_forces[natoms:] = virial.view(-1, 3) / self.cell_factor + + self.stress = -virial.view(-1, 9) / volumes.view(-1, 1) + + if self.numpy and not no_numpy: + augmented_forces = augmented_forces.cpu().numpy() + + return augmented_forces + + def __len__(self): + return len(self.batch.pos) + 3 * len(self.batch) diff --git a/src/fairchem/core/common/relaxation/optimizers/__init__.py b/src/fairchem/core/common/relaxation/optimizers/__init__.py index e69de29bb..1c7c27f9f 100644 --- a/src/fairchem/core/common/relaxation/optimizers/__init__.py +++ b/src/fairchem/core/common/relaxation/optimizers/__init__.py @@ -0,0 +1,12 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from .lbfgs_torch import LBFGS + +__all__ = ["LBFGS"] diff --git a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py index a90f0dce5..467c4bec4 100644 --- a/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py +++ b/src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py @@ -16,87 +16,66 @@ import torch from torch_scatter import scatter -from fairchem.core.common.relaxation.ase_utils import batch_to_atoms -from fairchem.core.common.utils import radius_graph_pbc - if TYPE_CHECKING: - from torch_geometric.data import Batch + from .optimizable import OptimizableBatch class LBFGS: + """Limited memory BFGS optimizer for batch ML relaxations.""" + def __init__( self, - batch: Batch, - model: TorchCalc, - maxstep: float = 0.01, + optimizable_batch: OptimizableBatch, + maxstep: float = 0.02, memory: int = 100, - damping: float = 0.25, + damping: float = 1.2, alpha: float = 100.0, - force_consistent=None, - device: str = "cuda:0", save_full_traj: bool = True, traj_dir: Path | None = None, - traj_names=None, - early_stop_batch: bool = False, + traj_names: list[str] | None = None, ) -> None: - self.batch = batch - self.model = model + """ + Args: + optimizable_batch: an optimizable batch which includes a model and a batch of data + maxstep: largest step that any atom is allowed to move + memory: Number of steps to be stored in memory + damping: The calculated step is multiplied with this number before added to the positions. + alpha: Initial guess for the Hessian (curvature of energy surface) + save_full_traj: wether to save full trajectory + traj_dir: path to save trajectories in + traj_names: list of trajectory files names + """ + self.optimizable = optimizable_batch self.maxstep = maxstep self.memory = memory self.damping = damping self.alpha = alpha self.H0 = 1.0 / self.alpha - self.force_consistent = force_consistent - self.device = device self.save_full = save_full_traj self.traj_dir = traj_dir self.traj_names = traj_names - self.early_stop_batch = early_stop_batch - self.otf_graph = True - assert not self.traj_dir or ( - traj_dir and len(traj_names) - ), "Trajectory names should be specified to save trajectories" - logging.info("Step Fmax(eV/A)") - - if not self.otf_graph and "edge_index" not in batch: - self.model.update_graph(self.batch) - - def get_energy_and_forces(self, apply_constraint: bool = True): - energy, forces = self.model.get_energy_and_forces(self.batch, apply_constraint) - return energy, forces - - def set_positions(self, update, update_mask) -> None: - if not self.early_stop_batch: - update = torch.where(update_mask.unsqueeze(1), update, 0.0) - self.batch.pos += update.to(dtype=torch.float32) - - if not self.otf_graph: - self.model.update_graph(self.batch) - - def check_convergence(self, iteration, forces=None, energy=None): - if forces is None or energy is None: - energy, forces = self.get_energy_and_forces() - forces = forces.to(dtype=torch.float64) + self.trajectories = None - max_forces_ = scatter( - (forces**2).sum(axis=1).sqrt(), self.batch.batch, reduce="max" - ) - logging.info( - f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces_.tolist()) - ) + self.fmax = None + self.steps = None - # (batch_size) -> (nAtoms) - max_forces = max_forces_[self.batch.batch] + self.s = deque(maxlen=self.memory) + self.y = deque(maxlen=self.memory) + self.rho = deque(maxlen=self.memory) + self.r0 = None + self.f0 = None - return max_forces.lt(self.fmax), energy, forces + assert not self.traj_dir or ( + traj_dir and len(traj_names) + ), "Trajectory names should be specified to save trajectories" def run(self, fmax, steps): self.fmax = fmax self.steps = steps - self.s = deque(maxlen=self.memory) - self.y = deque(maxlen=self.memory) - self.rho = deque(maxlen=self.memory) + self.s.clear() + self.y.clear() + self.rho.clear() self.r0 = self.f0 = None self.trajectories = None @@ -108,29 +87,33 @@ def run(self, fmax, steps): ] iteration = 0 - converged = False - converged_mask = torch.zeros_like( - self.batch.atomic_numbers, device=self.device - ).bool() - while iteration < steps and not converged: - _converged_mask, energy, forces = self.check_convergence(iteration) - # Models like GemNet-OC can have random noise in their predictions. - # Here we ensure atom positions are not being updated after already - # hitting the desired convergence criteria. - converged_mask = torch.logical_or(converged_mask, _converged_mask) - converged = torch.all(converged_mask) - update_mask = torch.logical_not(converged_mask) + max_forces = self.optimizable.get_max_forces(apply_constraint=True) + logging.info("Step Fmax(eV/A)") + + while iteration < steps and not self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces + ): + logging.info( + f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces.tolist()) + ) if self.trajectories is not None and ( - self.save_full or converged or iteration == steps - 1 or iteration == 0 + self.save_full is True or iteration == 0 ): - self.write(energy, forces, update_mask) - - if not converged and iteration < steps - 1: - self.step(iteration, forces, update_mask) + self.write() + self.step(iteration) + max_forces = self.optimizable.get_max_forces(apply_constraint=True) iteration += 1 + logging.info( + f"{iteration} " + " ".join(f"{x:0.3f}" for x in max_forces.tolist()) + ) + + # save after converged or all iterations ran + if iteration > 0 and self.trajectories is not None: + self.write() + # GPU memory usage as per nvidia-smi seems to gradually build up as # batches are processed. This releases unoccupied cached memory. torch.cuda.empty_cache() @@ -142,102 +125,79 @@ def run(self, fmax, steps): traj_fl = Path(self.traj_dir / f"{name}.traj_tmp", mode="w") traj_fl.rename(traj_fl.with_suffix(".traj")) - self.batch.energy, self.batch.force = self.get_energy_and_forces( - apply_constraint=False - ) - return self.batch + # set predicted values to batch + for name, value in self.optimizable.results.items(): + setattr(self.optimizable.batch, name, value) - def step( - self, - iteration: int, - forces: torch.Tensor | None, - update_mask: torch.Tensor, - ) -> None: - def _batched_dot(x: torch.Tensor, y: torch.Tensor): - return scatter((x * y).sum(dim=-1), self.batch.batch, reduce="sum") - - def determine_step(dr): - steplengths = torch.norm(dr, dim=1) - longest_steps = scatter(steplengths, self.batch.batch, reduce="max") - longest_steps = longest_steps[self.batch.batch] - maxstep = longest_steps.new_tensor(self.maxstep) - scale = (longest_steps + 1e-7).reciprocal() * torch.min( - longest_steps, maxstep - ) - dr *= scale.unsqueeze(1) - return dr * self.damping + return self.optimizable.converged( + forces=None, fmax=self.fmax, max_forces=max_forces + ) - if forces is None: - _, forces = self.get_energy_and_forces() + def determine_step(self, dr): + steplengths = torch.norm(dr, dim=1) + longest_steps = scatter( + steplengths, self.optimizable.batch_indices, reduce="max" + ) + longest_steps = longest_steps[self.optimizable.batch_indices] + maxstep = longest_steps.new_tensor(self.maxstep) + scale = (longest_steps + 1e-7).reciprocal() * torch.min(longest_steps, maxstep) + dr *= scale.unsqueeze(1) + return dr * self.damping + + def _batched_dot(self, x: torch.Tensor, y: torch.Tensor): + return scatter( + (x * y).sum(dim=-1), self.optimizable.batch_indices, reduce="sum" + ) - r = self.batch.pos.clone().to(dtype=torch.float64) + def step(self, iteration: int) -> None: + # cast forces and positions to float64 otherwise the algorithm is prone to overflow + forces = self.optimizable.get_forces(apply_constraint=True).to( + dtype=torch.float64 + ) + pos = self.optimizable.get_positions().to(dtype=torch.float64) # Update s, y, rho if iteration > 0: - s0 = r - self.r0 + s0 = pos - self.r0 self.s.append(s0) y0 = -(forces - self.f0) self.y.append(y0) - self.rho.append(1.0 / _batched_dot(y0, s0)) + self.rho.append(1.0 / self._batched_dot(y0, s0)) loopmax = min(self.memory, iteration) - alpha = forces.new_empty(loopmax, self.batch.natoms.shape[0]) + alpha = forces.new_empty(loopmax, self.optimizable.batch.natoms.shape[0]) q = -forces for i in range(loopmax - 1, -1, -1): - alpha[i] = self.rho[i] * _batched_dot(self.s[i], q) # b - q -= alpha[i][self.batch.batch, ..., None] * self.y[i] + alpha[i] = self.rho[i] * self._batched_dot(self.s[i], q) # b + q -= alpha[i][self.optimizable.batch_indices, ..., None] * self.y[i] z = self.H0 * q for i in range(loopmax): - beta = self.rho[i] * _batched_dot(self.y[i], z) + beta = self.rho[i] * self._batched_dot(self.y[i], z) z += self.s[i] * ( - alpha[i][self.batch.batch, ..., None] - - beta[self.batch.batch, ..., None] + alpha[i][self.optimizable.batch_indices, ..., None] + - beta[self.optimizable.batch_indices, ..., None] ) # descent direction p = -z - dr = determine_step(p) + dr = self.determine_step(p) + if torch.abs(dr).max() < 1e-7: # Same configuration again (maybe a restart): return - self.set_positions(dr, update_mask) - - self.r0 = r + self.optimizable.set_positions(pos + dr) + self.r0 = pos self.f0 = forces - def write(self, energy, forces, update_mask) -> None: - self.batch.energy, self.batch.force = energy, forces - atoms_objects = batch_to_atoms(self.batch) - update_mask_ = torch.split(update_mask, self.batch.natoms.tolist()) - for atm, traj, mask in zip(atoms_objects, self.trajectories, update_mask_): - if mask[0] or not self.save_full: + def write(self) -> None: + atoms_objects = self.optimizable.get_atoms_list() + for atm, traj, mask in zip( + atoms_objects, self.trajectories, self.optimizable.update_mask + ): + if mask: traj.write(atm) - - -class TorchCalc: - def __init__(self, model, transform=None) -> None: - self.model = model - self.transform = transform - - def get_energy_and_forces(self, atoms, apply_constraint: bool = True): - predictions = self.model.predict(atoms, per_image=False, disable_tqdm=True) - energy = predictions["energy"] - forces = predictions["forces"] - if apply_constraint: - fixed_idx = torch.where(atoms.fixed == 1)[0] - forces[fixed_idx] = 0 - return energy, forces - - def update_graph(self, atoms): - edge_index, cell_offsets, num_neighbors = radius_graph_pbc(atoms, 6, 50) - atoms.edge_index = edge_index - atoms.cell_offsets = cell_offsets - atoms.neighbors = num_neighbors - if self.transform is not None: - atoms = self.transform(atoms) - return atoms diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py index d688b8e79..ebbe1dfac 100644 --- a/src/fairchem/core/datasets/ase_datasets.py +++ b/src/fairchem/core/datasets/ase_datasets.py @@ -105,7 +105,7 @@ def __init__( if len(self.ids) == 0: raise ValueError( - rf"No valid ase data found!" + rf"No valid ase data found! \n" f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" ) @@ -142,7 +142,7 @@ def __getitem__(self, idx): data_object = self.transforms(data_object) if self.config.get("include_relaxed_energy", False): - data_object.y_relaxed = self.get_relaxed_energy(self.ids[idx]) + data_object.energy_relaxed = self.get_relaxed_energy(self.ids[idx]) return data_object @@ -160,9 +160,12 @@ def _load_dataset_get_ids(self, config): "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." ) - @abstractmethod def get_relaxed_energy(self, identifier): - raise NotImplementedError("IS2RE-Direct is not implemented with this dataset.") + raise NotImplementedError( + "Reading relaxed energy from trajectory or file is not implemented with this dataset. " + "If relaxed energies are saved with the atoms info dictionary, they can be used by passing the keys in " + "the r_data_keys argument under a2g_args." + ) def sample_property_metadata(self, num_samples: int = 100) -> dict: metadata = {} @@ -568,8 +571,3 @@ def sample_property_metadata(self, num_samples: int = 100) -> dict: return super().sample_property_metadata(num_samples) return copy.deepcopy(self.dbs[0].metadata) - - def get_relaxed_energy(self, identifier): - raise NotImplementedError( - "IS2RE-Direct training with an ASE DB is not currently supported." - ) diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index 32865e0ef..e6c3e0820 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -63,14 +63,10 @@ def generate_graph( use_pbc_single = use_pbc_single or self.use_pbc_single otf_graph = otf_graph or self.otf_graph - if enforce_max_neighbors_strictly is not None: - pass - elif hasattr(self, "enforce_max_neighbors_strictly"): - # Not all models will have this attribute - enforce_max_neighbors_strictly = self.enforce_max_neighbors_strictly - else: - # Default to old behavior - enforce_max_neighbors_strictly = True + if enforce_max_neighbors_strictly is None: + enforce_max_neighbors_strictly = getattr( + self, "enforce_max_neighbors_strictly", True + ) if not otf_graph: try: diff --git a/src/fairchem/core/preprocessing/atoms_to_graphs.py b/src/fairchem/core/preprocessing/atoms_to_graphs.py index f4b5a757b..fa679a262 100644 --- a/src/fairchem/core/preprocessing/atoms_to_graphs.py +++ b/src/fairchem/core/preprocessing/atoms_to_graphs.py @@ -250,7 +250,7 @@ def convert(self, atoms: ase.Atoms, sid=None): for data_key in self.r_data_keys: data[data_key] = ( atoms.info[data_key] - if isinstance(atoms.info[data_key], (int, float)) + if isinstance(atoms.info[data_key], (int, float, str)) else torch.Tensor(atoms.info[data_key]) ) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 90cdce0e5..5c21a743a 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -215,6 +215,7 @@ def __init__( self.test_dataset = None self.best_val_metric = None self.primary_metric = None + self.ema = None self.load(inference_only) @@ -361,7 +362,7 @@ def convert_settings_to_split_settings(config, split_name): ) self.train_sampler = self.get_sampler( self.train_dataset, - self.config["optim"]["batch_size"], + self.config["optim"].get("batch_size", 1), shuffle=True, ) self.train_loader = self.get_dataloader( @@ -392,7 +393,7 @@ def convert_settings_to_split_settings(config, split_name): self.val_sampler = self.get_sampler( self.val_dataset, self.config["optim"].get( - "eval_batch_size", self.config["optim"]["batch_size"] + "eval_batch_size", self.config["optim"].get("batch_size", 1) ), shuffle=False, ) @@ -414,7 +415,7 @@ def convert_settings_to_split_settings(config, split_name): self.test_sampler = self.get_sampler( self.test_dataset, self.config["optim"].get( - "eval_batch_size", self.config["optim"]["batch_size"] + "eval_batch_size", self.config["optim"].get("batch_size", 1) ), shuffle=False, ) diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index a8976773c..8e5d17820 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -552,7 +552,7 @@ def predict( return predictions @torch.no_grad - def run_relaxations(self, split="val"): + def run_relaxations(self): ensure_fitted(self._unwrapped_model) # When set to true, uses deterministic CUDA scatter ops, if available. @@ -572,14 +572,14 @@ def run_relaxations(self, split="val"): evaluator_is2rs, metrics_is2rs = Evaluator(task="is2rs"), {} evaluator_is2re, metrics_is2re = Evaluator(task="is2re"), {} - # Need both `pos_relaxed` and `y_relaxed` to compute val IS2R* metrics. + # Need both `pos_relaxed` and `energy_relaxed` to compute val IS2R* metrics. # Else just generate predictions. if ( hasattr(self.relax_dataset[0], "pos_relaxed") and self.relax_dataset[0].pos_relaxed is not None ) and ( - hasattr(self.relax_dataset[0], "y_relaxed") - and self.relax_dataset[0].y_relaxed is not None + hasattr(self.relax_dataset[0], "energy_relaxed") + and self.relax_dataset[0].energy_relaxed is not None ): split = "val" else: @@ -608,9 +608,10 @@ def run_relaxations(self, split="val"): model=self, steps=self.config["task"].get("relaxation_steps", 300), fmax=self.config["task"].get("relaxation_fmax", 0.02), + relax_cell=self.config["task"].get("relax_cell", False), + relax_volume=self.config["task"].get("relax_volume", False), relax_opt=self.config["task"]["relax_opt"], save_full_traj=self.config["task"].get("save_full_traj", True), - device=self.device, transform=None, ) @@ -638,7 +639,7 @@ def run_relaxations(self, split="val"): s_idx += natoms target = { - "energy": relaxed_batch.energy, + "energy": relaxed_batch.energy_relaxed, "positions": relaxed_batch.pos_relaxed[mask], "cell": relaxed_batch.cell, "pbc": torch.tensor([True, True, True]), diff --git a/tests/core/common/conftest.py b/tests/core/common/conftest.py new file mode 100644 index 000000000..6187cbf3a --- /dev/null +++ b/tests/core/common/conftest.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pytest +from ase import build + +from fairchem.core.common.relaxation.ase_utils import OCPCalculator +from fairchem.core.datasets import data_list_collater +from fairchem.core.preprocessing.atoms_to_graphs import AtomsToGraphs + + +@pytest.fixture(scope="session") +def calculator(tmp_path_factory): + dir = tmp_path_factory.mktemp("checkpoints") + return OCPCalculator( + model_name="EquiformerV2-31M-S2EF-OC20-All+MD", local_cache=dir, seed=0 + ) + + +@pytest.fixture() +def atoms_list(): + atoms_list = [ + build.bulk("Cu", "fcc", a=3.8, cubic=True), + build.bulk("NaCl", crystalstructure="rocksalt", a=5.8), + ] + for atoms in atoms_list: + atoms.rattle(stdev=0.05, seed=0) + return atoms_list + + +@pytest.fixture() +def batch(atoms_list): + a2g = AtomsToGraphs(r_edges=False, r_pbc=True) + return data_list_collater([a2g.convert(atoms) for atoms in atoms_list]) diff --git a/tests/core/common/test_ase_calculator.py b/tests/core/common/test_ase_calculator.py index 3d62c35e1..92baa37cb 100644 --- a/tests/core/common/test_ase_calculator.py +++ b/tests/core/common/test_ase_calculator.py @@ -65,6 +65,9 @@ def test_relaxation_final_energy(atoms, tmp_path, snapshot) -> None: cpu=True, ) + assert "energy" in calc.implemented_properties + assert "forces" in calc.implemented_properties + atoms.set_calculator(calc) opt = BFGS(atoms) opt.run(fmax=0.05, steps=100) diff --git a/tests/core/common/test_lbfgs_torch.py b/tests/core/common/test_lbfgs_torch.py new file mode 100644 index 000000000..7bcf743eb --- /dev/null +++ b/tests/core/common/test_lbfgs_torch.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from itertools import combinations, product + +import numpy as np +import numpy.testing as npt +import pytest +from ase.io import read +from ase.optimize import LBFGS as LBFGS_ASE + +from fairchem.core.common.relaxation import OptimizableBatch +from fairchem.core.common.relaxation.optimizers import LBFGS +from fairchem.core.modules.evaluator import min_diff + + +def test_lbfgs_relaxation(atoms_list, batch, calculator): + """Tests batch relaxation using fairchem LBFGS optimizer.""" + obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False) + + # optimize atoms one-by-one + for atoms in atoms_list: + atoms.calc = calculator + opt = LBFGS_ASE(atoms, damping=0.8, alpha=70.0) + opt.run(0.01, 20) + + # optimize atoms in batch using ASE + batch_optimizer = LBFGS(obatch, damping=0.8, alpha=70.0) + batch_optimizer.run(0.01, 20) + + # compare energy and atom positions, this needs pretty slack tols but that should be ok + for a1, a2 in zip(atoms_list, obatch.get_atoms_list()): + assert a1.get_potential_energy() / len(a1) == pytest.approx( + a2.get_potential_energy() / len(a2), abs=0.05 + ) + diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc) + npt.assert_allclose(diff, 0, atol=0.01) + + +@pytest.mark.parametrize( + ("save_full_traj", "steps"), list(product((True, False), (0, 1, 5))) +) +def test_lbfgs_write_trajectory(save_full_traj, steps, batch, calculator, tmp_path): + obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=False) + batch_optimizer = LBFGS( + obatch, + save_full_traj=save_full_traj, + traj_dir=tmp_path, + traj_names=[f"system-{i}" for i in range(len(batch))], + ) + + batch_optimizer.run(0.001, steps=steps) + + # check that trajectory files where written + traj_files = list(tmp_path.glob("*.traj")) + assert len(traj_files) == len(batch) + + traj_length = ( + 0 if steps == 0 else steps + 1 if save_full_traj else 2 + ) # first and final frame + for file in traj_files: + traj = read(file, ":") + assert len(traj) == traj_length + + # make sure all written frames are unique + for a1, a2 in combinations(traj, r=2): + assert not np.allclose(a1.positions, a2.positions, atol=1e-5) diff --git a/tests/core/common/test_optimizable.py b/tests/core/common/test_optimizable.py new file mode 100644 index 000000000..7024d9128 --- /dev/null +++ b/tests/core/common/test_optimizable.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import numpy as np +import numpy.testing as npt +import pytest +from ase.optimize import BFGS, FIRE, LBFGS + +try: + from ase.filters import UnitCellFilter +except ModuleNotFoundError: + # older ase version, import UnitCellFilterOld + from ase.constraints import UnitCellFilter + +from fairchem.core.common.relaxation import OptimizableBatch, OptimizableUnitCellBatch +from fairchem.core.datasets import data_list_collater +from fairchem.core.modules.evaluator import min_diff + + +@pytest.fixture(params=[FIRE, BFGS, LBFGS]) +def optimizer_cls(request): + return request.param + + +def test_ase_relaxation(atoms_list, batch, calculator, optimizer_cls): + """Tests batch relaxation using ASE optimizers.""" + obatch = OptimizableBatch(batch, trainer=calculator.trainer, numpy=True) + + # optimize atoms one-by-one + for atoms in atoms_list: + atoms.calc = calculator + opt = optimizer_cls(atoms) + opt.run(0.01, 20) + + # optimize atoms in batch using ASE + batch_optimizer = optimizer_cls(obatch) + batch_optimizer.run(0.01, 20) + + # compare energy and atom positions, this needs pretty slack tols but that should be ok + for a1, a2 in zip(atoms_list, obatch.get_atoms_list()): + assert a1.get_potential_energy() / len(a1) == pytest.approx( + a2.get_potential_energy() / len(a2), abs=0.05 + ) + diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc) + npt.assert_allclose(diff, 0, atol=0.01) + + +@pytest.mark.parametrize("mask_converged", [False, True]) +def test_batch_relaxation_mask(atoms_list, calculator, mask_converged): + """Test that masking is working as intended!""" + # relax only the first atom in list + atoms = atoms_list[0] + atoms.calc = calculator + opt = LBFGS(atoms) + opt.run(0.01, 50) + assert ((atoms.get_forces() ** 2).sum(axis=1) ** 0.5 <= 0.01).all() + + # now create a batch + batch = data_list_collater([calculator.a2g.convert(atoms) for atoms in atoms_list]) + obatch = OptimizableBatch( + batch, trainer=calculator.trainer, numpy=True, mask_converged=mask_converged + ) + + npt.assert_allclose(batch.pos[batch.batch == 0].cpu().numpy(), atoms.positions) + batch_opt = LBFGS(obatch) + batch_opt.run(0.01, 20) + + if mask_converged: + # assert preconverged structure was not changed at all + npt.assert_allclose(batch.pos[batch.batch == 0].cpu().numpy(), atoms.positions) + assert not np.allclose( + batch.pos[batch.batch == 1].cpu().numpy(), atoms_list[1].positions + ) + else: + # assert that it was changed + assert not np.allclose( + batch.pos[batch.batch == 0].cpu().numpy(), atoms.positions + ) + + +@pytest.mark.skip("Skip until we have a test model that can predict stress") +def test_ase_cell_relaxation(atoms_list, batch, calculator, optimizer_cls): + """Tests batch relaxation using ASE optimizers.""" + cell_factor = batch.natoms.cpu().numpy().mean() + obatch = OptimizableUnitCellBatch( + batch, trainer=calculator.trainer, numpy=True, cell_factor=cell_factor + ) + + # optimize atoms in batch using ASE + batch_optimizer = optimizer_cls(obatch) + batch_optimizer.run(0.01, 20) + + # optimize atoms one-by-one + for atoms in atoms_list: + print(atoms.cell.array) + atoms.calc = calculator + opt = optimizer_cls(UnitCellFilter(atoms, cell_factor=cell_factor)) + opt.run(0.01, 20) + + # compare energy, atom positions and cell + for a1, a2 in zip(atoms_list, obatch.get_atoms_list()): + assert a1.get_potential_energy() / len(a1) == pytest.approx( + a2.get_potential_energy() / len(a2), abs=0.05 + ) + diff = min_diff(a1.positions, a2.positions, a1.get_cell(), pbc=a1.pbc) + npt.assert_allclose(diff, 0, atol=0.05, rtol=0.05) + + cnorm1 = np.linalg.norm(a1.cell.array, axis=1) + cnorm2 = np.linalg.norm(a2.cell.array, axis=1) + npt.assert_allclose(cnorm1, cnorm2, atol=0.01, rtol=0.01) + npt.assert_allclose(a1.cell.array.T, a2.cell.array.T, rtol=0.01, atol=0.01) diff --git a/tests/core/datasets/test_ase_datasets.py b/tests/core/datasets/test_ase_datasets.py index 7b114d877..676805c65 100644 --- a/tests/core/datasets/test_ase_datasets.py +++ b/tests/core/datasets/test_ase_datasets.py @@ -228,9 +228,9 @@ def test_ase_multiread_dataset(tmp_path): assert len(dataset) == len(atoms_objects) - assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].energy - assert dataset[-1].y_relaxed == dataset[-1].energy + assert hasattr(dataset[0], "energy_relaxed") + assert dataset[0].energy_relaxed != dataset[0].energy + assert dataset[-1].energy_relaxed == dataset[-1].energy dataset = AseReadDataset( config={ @@ -247,8 +247,8 @@ def test_ase_multiread_dataset(tmp_path): } ) - assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].energy + assert hasattr(dataset[0], "energy_relaxed") + assert dataset[0].energy_relaxed != dataset[0].energy def test_empty_dataset(tmp_path): From 816cf009e86de956cecbce31d651b6455d03f4b3 Mon Sep 17 00:00:00 2001 From: Misko Date: Tue, 3 Dec 2024 17:38:32 -0800 Subject: [PATCH 6/7] pin pytorch at 2_4_0 (#928) * pin pytorch at 2_4_0 * also fix pypi package --------- Co-authored-by: misko user --- packages/env.cpu.yml | 2 +- packages/env.gpu.yml | 2 +- packages/fairchem-core/pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/env.cpu.yml b/packages/env.cpu.yml index 985d8b721..a9cba3700 100644 --- a/packages/env.cpu.yml +++ b/packages/env.cpu.yml @@ -4,7 +4,7 @@ channels: - defaults dependencies: - cpuonly -- pytorch>=2.4 +- pytorch==2.4.0 - ase - e3nn>=0.5 - numpy >=1.26.0,<2.0.0 diff --git a/packages/env.gpu.yml b/packages/env.gpu.yml index 50b0c6231..0e199c055 100644 --- a/packages/env.gpu.yml +++ b/packages/env.gpu.yml @@ -5,7 +5,7 @@ channels: - defaults dependencies: - pytorch-cuda=12.1 -- pytorch>=2.4 +- pytorch==2.4.0 - ase - e3nn>=0.5 - numpy >=1.26.0,<2.0.0 diff --git a/packages/fairchem-core/pyproject.toml b/packages/fairchem-core/pyproject.toml index dfd5c671a..5081f624b 100644 --- a/packages/fairchem-core/pyproject.toml +++ b/packages/fairchem-core/pyproject.toml @@ -9,7 +9,7 @@ license = {text = "MIT License"} dynamic = ["version", "readme"] requires-python = ">=3.9, <3.13" dependencies = [ - "torch>=2.4", + "torch==2.4", "numpy >=1.26.0, <2.0.0", "lmdb", "ase", From 6ab6ad72da6f24a5937140a6f4a04891fc351ea6 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Wed, 4 Dec 2024 22:32:11 -0800 Subject: [PATCH 7/7] Add cuda set_device for local run (#931) * add set local device * fix test --------- Co-authored-by: rgao --- src/fairchem/core/common/distutils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index 604f969a8..d1569856a 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -80,7 +80,8 @@ def setup(config) -> None: assign_device_for_local_rank(config["cpu"], config["local_rank"]) else: # in the old code, all ranks can see all devices but need to be assigned a device equal to their local rank - # this is dangerous and should be deprecated + # this is dangerous and should be deprecated, however, FSDP still requires backwards compatibility with + # initializing this way for now so we need to keep it torch.cuda.set_device(config["local_rank"]) dist.init_process_group( @@ -123,6 +124,11 @@ def setup(config) -> None: config["local_rank"] = int(os.environ.get("LOCAL_RANK")) if config.get("use_cuda_visibile_devices"): assign_device_for_local_rank(config["cpu"], config["local_rank"]) + elif torch.cuda.is_available(): + # in the old code, all ranks can see all devices but need to be assigned a device equal to their local rank + # this is dangerous and should be deprecated, however, FSDP still requires backwards compatibility with + # initializing this way for now so we need to keep it + torch.cuda.set_device(config["local_rank"]) dist.init_process_group( backend=config["distributed_backend"], rank=int(os.environ.get("RANK")),