diff --git a/TRAIN.md b/TRAIN.md index d16d08dfb..03719e23d 100644 --- a/TRAIN.md +++ b/TRAIN.md @@ -204,11 +204,11 @@ To train and validate an OC20 IS2RE/S2EF model on total energies instead of adso ```yaml task: - dataset: oc22_lmdb prediction_dtype: float32 ... dataset: + format: oc22_lmdb train: src: data/oc20/s2ef/train normalize_labels: False @@ -308,8 +308,8 @@ For the IS2RE-Total task, the model takes the initial structure as input and pre ```yaml trainer: energy # Use the EnergyTrainer -task: - dataset: oc22_lmdb # Use the OC22LmdbDataset +dataset: + format: oc22_lmdb # Use the OC22LmdbDataset ... ``` You can find examples configuration files in [`configs/oc22/is2re`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/is2re). @@ -321,8 +321,8 @@ The S2EF-Total task takes a structure and predicts the total DFT energy and per- ```yaml trainer: forces # Use the ForcesTrainer -task: - dataset: oc22_lmdb # Use the OC22LmdbDataset +dataset: + format: oc22_lmdb # Use the OC22LmdbDataset ... ``` You can find examples configuration files in [`configs/oc22/s2ef`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/s2ef). @@ -332,8 +332,8 @@ You can find examples configuration files in [`configs/oc22/s2ef`](https://githu Training on OC20 total energies whether independently or jointly with OC22 requires a path to the `oc20_ref` (download link provided below) to be specified in the configuration file. These are necessary to convert OC20 adsorption energies into their corresponding total energies. The following changes in the configuration file capture these changes: ```yaml -task: - dataset: oc22_lmdb +dataset: + format: oc22_lmdb ... dataset: @@ -382,10 +382,8 @@ If your data is already in an [ASE Database](https://databases.fysik.dtu.dk/ase/ To use this dataset, we will just have to change our config files to use the ASE DB Dataset rather than the LMDB Dataset: ```yaml -task: - dataset: ase_db - dataset: + format: ase_db train: src: # The path/address to your ASE DB connect_args: @@ -420,10 +418,8 @@ It is possible to train/predict directly on ASE-readable files. This is only rec This dataset assumes a single structure will be obtained from each file: ```yaml -task: - dataset: ase_read - dataset: + format: ase_read train: src: # The folder that contains ASE-readable files pattern: # Pattern matching each file you want to read (e.g. "*/POSCAR"). Search recursively with two wildcards: "**/*.cif". @@ -443,10 +439,8 @@ dataset: This dataset supports reading files that each contain multiple structure (for example, an ASE .traj file). Using an index file, which tells the dataset how many structures each file contains, is recommended. Otherwise, the dataset is forced to load every file at startup and count the number of structures! ```yaml -task: - dataset: ase_read_multi - dataset: + format: ase_read_multi train: index_file: Filepath to an index file which contains each filename and the number of structures in each file. e.g.: /path/to/relaxation1.traj 200 diff --git a/env.common.yml b/env.common.yml index 7a22ee2ce..c12632b6d 100644 --- a/env.common.yml +++ b/env.common.yml @@ -7,6 +7,7 @@ dependencies: - ase=3.22.1 - black=22.3.0 - e3nn=0.4.4 +- numpy=1.23.5 - matplotlib - numba - orjson diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index 8c47b0ab7..bdc1544d1 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -969,7 +969,12 @@ def check_traj_files(batch, traj_dir) -> bool: if traj_dir is None: return False traj_dir = Path(traj_dir) - traj_files = [traj_dir / f"{id}.traj" for id in batch.sid.tolist()] + sid_list = ( + batch.sid.tolist() + if isinstance(batch.sid, torch.Tensor) + else batch.sid + ) + traj_files = [traj_dir / f"{sid}.traj" for sid in sid_list] return all(fl.exists() for fl in traj_files) @@ -1204,13 +1209,22 @@ def update_config(base_config): are now. Update old configs to fit the new expected structure. """ config = copy.deepcopy(base_config) - config["dataset"]["format"] = config["task"].get("dataset", "lmdb") + + # If config["dataset"]["format"] is missing, get it from the task (legacy location). + # If it is not there either, default to LMDB. + config["dataset"]["format"] = config["dataset"].get( + "format", config["task"].get("dataset", "lmdb") + ) + ### Read task based off config structure, similar to OCPCalculator. if config["task"]["dataset"] in [ "trajectory_lmdb", "lmdb", "trajectory_lmdb_v2", "oc22_lmdb", + "ase_read", + "ase_read_multi", + "ase_db", ]: task = "s2ef" elif config["task"]["dataset"] == "single_point_lmdb": diff --git a/ocpmodels/datasets/_utils.py b/ocpmodels/datasets/_utils.py new file mode 100644 index 000000000..c0c17db08 --- /dev/null +++ b/ocpmodels/datasets/_utils.py @@ -0,0 +1,33 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + from torch_geometric.data import Data + + +def rename_data_object_keys( + data_object: Data, key_mapping: dict[str, str] +) -> Data: + """Rename data object keys + + Args: + data_object: data object + key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key} + """ + for _property in key_mapping: + # catch for test data not containing labels + if _property in data_object: + new_property = key_mapping[_property] + if new_property not in data_object: + data_object[new_property] = data_object[_property] + del data_object[_property] + + return data_object diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 9e4d76b43..cdecd82df 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -1,13 +1,22 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + import bisect import copy -import functools -import glob import logging import os import warnings from abc import ABC, abstractmethod +from functools import cache, reduce +from glob import glob from pathlib import Path -from typing import List +from typing import Any, Callable, Optional import ase import numpy as np @@ -16,8 +25,10 @@ from tqdm import tqdm from ocpmodels.common.registry import registry +from ocpmodels.datasets._utils import rename_data_object_keys from ocpmodels.datasets.lmdb_database import LMDBDatabase from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata +from ocpmodels.modules.transforms import DataTransforms from ocpmodels.preprocessing import AtomsToGraphs @@ -65,25 +76,39 @@ class AseAtomsDataset(Dataset, ABC): """ def __init__( - self, config, transform=None, atoms_transform=apply_one_tags + self, + config: dict, + atoms_transform: Callable[ + [ase.Atoms, Any, ...], ase.Atoms + ] = apply_one_tags, ) -> None: self.config = config - a2g_args = config.get("a2g_args", {}) - if a2g_args is None: - a2g_args = {} + a2g_args = config.get("a2g_args", {}) or {} + + # set default to False if not set by user, assuming otf_graph will be used + if "r_edges" not in a2g_args: + a2g_args["r_edges"] = False # Make sure we always include PBC info in the resulting atoms objects a2g_args["r_pbc"] = True self.a2g = AtomsToGraphs(**a2g_args) - self.transform = transform + self.key_mapping = self.config.get("key_mapping", None) + self.transforms = DataTransforms(self.config.get("transforms", {})) + self.atoms_transform = atoms_transform if self.config.get("keep_in_memory", False): - self.__getitem__ = functools.cache(self.__getitem__) + self.__getitem__ = cache(self.__getitem__) + + self.ids = self._load_dataset_get_ids(config) - self.ids = self.load_dataset_get_ids(config) + if len(self.ids) == 0: + raise ValueError( + rf"No valid ase data found!" + f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" + ) def __len__(self) -> int: return len(self.ids) @@ -91,10 +116,10 @@ def __len__(self) -> int: def __getitem__(self, idx): # Handle slicing if isinstance(idx, slice): - return [self[i] for i in range(*idx.indices(len(self.ids)))] + return [self[i] for i in range(*idx.indices(len(self)))] # Get atoms object via derived class method - atoms = self.get_atoms_object(self.ids[idx]) + atoms = self.get_atoms(self.ids[idx]) # Transform atoms object if self.atoms_transform is not None: @@ -103,14 +128,6 @@ def __getitem__(self, idx): ) sid = atoms.info.get("sid", self.ids[idx]) - try: - sid = tensor([sid]) - warnings.warn( - "Supplied sid is not numeric (or missing). Using dataset indices instead." - ) - except: - sid = tensor([idx]) - fid = atoms.info.get("fid", tensor([0])) # Convert to data object @@ -118,42 +135,50 @@ def __getitem__(self, idx): data_object.fid = fid data_object.natoms = len(atoms) - # Transform data object - if self.transform is not None: - data_object = self.transform( - data_object, **self.config.get("transform_args", {}) + if self.key_mapping is not None: + data_object = rename_data_object_keys( + data_object, self.key_mapping ) + # Transform data object + data_object = self.transforms(data_object) + if self.config.get("include_relaxed_energy", False): data_object.y_relaxed = self.get_relaxed_energy(self.ids[idx]) return data_object @abstractmethod - def get_atoms_object(self, identifier): + def get_atoms(self, idx: str | int) -> ase.Atoms: # This function should return an ASE atoms object. raise NotImplementedError( "Returns an ASE atoms object. Derived classes should implement this function." ) @abstractmethod - def load_dataset_get_ids(self, config): + def _load_dataset_get_ids(self, config): # This function should return a list of ids that can be used to index into the database raise NotImplementedError( "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." ) + @abstractmethod + def get_relaxed_energy(self, identifier): + raise NotImplementedError( + "IS2RE-Direct is not implemented with this dataset." + ) + def close_db(self) -> None: # This method is sometimes called by a trainer pass - def guess_target_metadata(self, num_samples: int = 100): + def get_metadata(self, num_samples: int = 100) -> dict: metadata = {} if num_samples < len(self): metadata["targets"] = guess_property_metadata( [ - self.get_atoms_object(self.ids[idx]) + self.get_atoms(self.ids[idx]) for idx in np.random.choice( len(self), size=(num_samples,), replace=False ) @@ -161,17 +186,11 @@ def guess_target_metadata(self, num_samples: int = 100): ) else: metadata["targets"] = guess_property_metadata( - [ - self.get_atoms_object(self.ids[idx]) - for idx in range(len(self)) - ] + [self.get_atoms(self.ids[idx]) for idx in range(len(self))] ) return metadata - def get_metadata(self): - return self.guess_target_metadata() - @registry.register_dataset("ase_read") class AseReadDataset(AseAtomsDataset): @@ -196,7 +215,7 @@ class AseReadDataset(AseAtomsDataset): default options will work for most users If you are using this for a training dataset, set - "r_energy":True and/or "r_forces":True as appropriate + "r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate In that case, energy/forces must be in the files you read (ex. OUTCAR) ase_read_args (dict): Keyword arguments for ase.io.read() @@ -213,14 +232,15 @@ class AseReadDataset(AseAtomsDataset): transform_args (dict): Additional keyword arguments for the transform callable + key_mapping (dict[str, str]): Dictionary specifying a mapping between the name of a property used + in the model with the corresponding property as it was named in the dataset. Only need to use if + the name is different. + atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms object. Useful for applying tags, for example. - - transform (callable, optional): Additional preprocessing function for the Data object - """ - def load_dataset_get_ids(self, config) -> List[Path]: + def _load_dataset_get_ids(self, config) -> list[Path]: self.ase_read_args = config.get("ase_read_args", {}) if ":" in self.ase_read_args.get("index", ""): @@ -230,24 +250,26 @@ def load_dataset_get_ids(self, config) -> List[Path]: self.path = Path(config["src"]) if self.path.is_file(): - raise Exception("The specified src is not a directory") + raise ValueError( + f"The specified src is not a directory: {self.config['src']}" + ) if self.config.get("include_relaxed_energy", False): self.relaxed_ase_read_args = copy.deepcopy(self.ase_read_args) self.relaxed_ase_read_args["index"] = "-1" - return list(self.path.glob(f'{config["pattern"]}')) + return list(self.path.glob(f'{config.get("pattern", "*")}')) - def get_atoms_object(self, identifier): + def get_atoms(self, idx: str | int) -> ase.Atoms: try: - atoms = ase.io.read(identifier, **self.ase_read_args) + atoms = ase.io.read(idx, **self.ase_read_args) except Exception as err: - warnings.warn(f"{err} occured for: {identifier}") + warnings.warn(f"{err} occured for: {idx}", stacklevel=2) raise err return atoms - def get_relaxed_energy(self, identifier): + def get_relaxed_energy(self, identifier) -> float: relaxed_atoms = ase.io.read(identifier, **self.relaxed_ase_read_args) return relaxed_atoms.get_potential_energy(apply_constraint=False) @@ -286,7 +308,7 @@ class AseReadMultiStructureDataset(AseAtomsDataset): default options will work for most users If you are using this for a training dataset, set - "r_energy":True and/or "r_forces":True as appropriate + "r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate In that case, energy/forces must be in the files you read (ex. OUTCAR) ase_read_args (dict): Keyword arguments for ase.io.read() @@ -305,24 +327,28 @@ class AseReadMultiStructureDataset(AseAtomsDataset): transform_args (dict): Additional keyword arguments for the transform callable + key_mapping (dict[str, str]): Dictionary specifying a mapping between the name of a property used + in the model with the corresponding property as it was named in the dataset. Only need to use if + the name is different. + atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms object. Useful for applying tags, for example. transform (callable, optional): Additional preprocessing function for the Data object """ - def load_dataset_get_ids(self, config): + def _load_dataset_get_ids(self, config) -> list[str]: self.ase_read_args = config.get("ase_read_args", {}) if not hasattr(self.ase_read_args, "index"): self.ase_read_args["index"] = ":" if config.get("index_file", None) is not None: - f = open(config["index_file"], "r") - index = f.readlines() + with open(config["index_file"], "r") as f: + index = f.readlines() ids = [] for line in index: - filename = line.split(" ")[0] + filename = line.split(" ", maxsplit=1)[0] for i in range(int(line.split(" ")[1])): ids.append(f"{filename} {i}") @@ -330,8 +356,11 @@ def load_dataset_get_ids(self, config): self.path = Path(config["src"]) if self.path.is_file(): - raise Exception("The specified src is not a directory") - filenames = list(self.path.glob(f'{config["pattern"]}')) + raise ValueError( + f"The specified src is not a directory: {self.config['src']}" + ) + + filenames = list(self.path.glob(f'{config.get("pattern", "*")}')) ids = [] @@ -341,65 +370,40 @@ def load_dataset_get_ids(self, config): try: structures = ase.io.read(filename, **self.ase_read_args) except Exception as err: - warnings.warn(f"{err} occured for: {filename}") + warnings.warn(f"{err} occured for: {filename}", stacklevel=2) else: - for i, structure in enumerate(structures): + for i, _ in enumerate(structures): ids.append(f"{filename} {i}") return ids - def get_atoms_object(self, identifier): + def get_atoms(self, idx: str) -> ase.Atoms: try: + identifiers = idx.split(" ") atoms = ase.io.read( - "".join(identifier.split(" ")[:-1]), **self.ase_read_args - )[int(identifier.split(" ")[-1])] + "".join(identifiers[:-1]), **self.ase_read_args + )[int(identifiers[-1])] except Exception as err: - warnings.warn(f"{err} occured for: {identifier}") + warnings.warn(f"{err} occured for: {idx}", stacklevel=2) raise err if "sid" not in atoms.info: - atoms.info["sid"] = "".join(identifier.split(" ")[:-1]) + atoms.info["sid"] = "".join(identifiers[:-1]) if "fid" not in atoms.info: - atoms.info["fid"] = int(identifier.split(" ")[-1]) + atoms.info["fid"] = int(identifiers[-1]) return atoms - def get_metadata(self): + def get_metadata(self, num_samples: int = 100) -> dict: return {} - def get_relaxed_energy(self, identifier): + def get_relaxed_energy(self, identifier) -> float: relaxed_atoms = ase.io.read( "".join(identifier.split(" ")[:-1]), **self.ase_read_args )[-1] return relaxed_atoms.get_potential_energy(apply_constraint=False) -class dummy_list(list): - def __init__(self, max) -> None: - self.max = max - return - - def __len__(self): - return self.max - - def __getitem__(self, idx): - # Handle slicing - if isinstance(idx, slice): - return [self[i] for i in range(*idx.indices(self.max))] - - # Cast idx as int since it could be a tensor index - idx = int(idx) - - # Handle negative indices (referenced from end) - if idx < 0: - idx += self.max - - if 0 <= idx < self.max: - return idx - else: - raise IndexError - - @registry.register_dataset("ase_db") class AseDBDataset(AseAtomsDataset): """ @@ -415,6 +419,7 @@ class AseDBDataset(AseAtomsDataset): - the path an ASE DB, - the connection address of an ASE DB, - a folder with multiple ASE DBs, + - a list of folders with ASE DBs - a glob string to use to find ASE DBs, or - a list of ASE db paths/addresses. If a folder, every file will be attempted as an ASE DB, and warnings @@ -435,7 +440,7 @@ class AseDBDataset(AseAtomsDataset): default options will work for most users If you are using this for a training dataset, set - "r_energy":True and/or "r_forces":True as appropriate + "r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate In that case, energy/forces must be in the database keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need @@ -444,23 +449,34 @@ class AseDBDataset(AseAtomsDataset): atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable - transform_args (dict): Additional keyword arguments for the transform callable + transforms (dict[str, dict]): Dictionary specifying data transforms as {transform_function: config} + where config is a dictionary specifying arguments to the transform_function + + key_mapping (dict[str, str]): Dictionary specifying a mapping between the name of a property used + in the model with the corresponding property as it was named in the dataset. Only need to use if + the name is different. atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms object. Useful for applying tags, for example. - transform (callable, optional): Additional preprocessing function for the Data object + transform (callable, optional): deprecated? """ - def load_dataset_get_ids(self, config) -> dummy_list: + def _load_dataset_get_ids(self, config: dict) -> list[int]: if isinstance(config["src"], list): - filepaths = config["src"] + if os.path.isdir(config["src"][0]): + filepaths = reduce( + lambda x, y: x + y, + (glob(f"{path}/*") for path in config["src"]), + ) + else: + filepaths = config["src"] elif os.path.isfile(config["src"]): filepaths = [config["src"]] elif os.path.isdir(config["src"]): - filepaths = glob.glob(f'{config["src"]}/*') + filepaths = glob(f'{config["src"]}/*') else: - filepaths = glob.glob(config["src"]) + filepaths = glob(config["src"]) self.dbs = [] @@ -470,7 +486,7 @@ def load_dataset_get_ids(self, config) -> dummy_list: self.connect_db(path, config.get("connect_args", {})) ) except ValueError: - logging.warning( + logging.debug( f"Tried to connect to {path} but it's not an ASE database!" ) @@ -488,6 +504,7 @@ def load_dataset_get_ids(self, config) -> dummy_list: if hasattr(db, "ids") and self.select_args == {}: self.db_ids.append(db.ids) else: + # this is the slow alternative self.db_ids.append( [row.id for row in db.select(**self.select_args)] ) @@ -495,9 +512,16 @@ def load_dataset_get_ids(self, config) -> dummy_list: idlens = [len(ids) for ids in self.db_ids] self._idlen_cumulative = np.cumsum(idlens).tolist() - return dummy_list(sum(idlens)) + return list(range(sum(idlens))) + + def get_atoms(self, idx: int) -> ase.Atoms: + """Get atoms object corresponding to datapoint idx. Useful to read other properties not in data object. + Args: + idx (int): index in dataset - def get_atoms_object(self, idx): + Returns: + atoms: ASE atoms corresponding to datapoint idx + """ # Figure out which db this should be indexed from. db_idx = bisect.bisect(self._idlen_cumulative, idx) @@ -510,35 +534,40 @@ def get_atoms_object(self, idx): atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx]) atoms = atoms_row.toatoms() + # put data back into atoms info if isinstance(atoms_row.data, dict): atoms.info.update(atoms_row.data) return atoms - def connect_db(self, address, connect_args={}): + @staticmethod + def connect_db( + address: str | Path, connect_args: Optional[dict] = None + ) -> ase.db.core.Database: if connect_args is None: connect_args = {} db_type = connect_args.get("type", "extract_from_name") - if db_type == "lmdb" or ( - db_type == "extract_from_name" and address.split(".")[-1] == "lmdb" + if db_type in ("lmdb", "aselmdb") or ( + db_type == "extract_from_name" + and str(address).rsplit(".", maxsplit=1)[-1] in ("lmdb", "aselmdb") ): return LMDBDatabase(address, readonly=True, **connect_args) - else: - return ase.db.connect(address, **connect_args) + + return ase.db.connect(address, **connect_args) def close_db(self) -> None: for db in self.dbs: if hasattr(db, "close"): db.close() - def get_metadata(self): + def get_metadata(self, num_samples: int = 100) -> dict: logging.warning( "You specific a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" ) if self.dbs[0].metadata == {}: - return self.guess_target_metadata() - else: - return copy.deepcopy(self.dbs[0].metadata) + return super().get_metadata(num_samples) + + return copy.deepcopy(self.dbs[0].metadata) def get_relaxed_energy(self, identifier): raise NotImplementedError( diff --git a/ocpmodels/datasets/lmdb_database.py b/ocpmodels/datasets/lmdb_database.py index 214315067..2264d2519 100644 --- a/ocpmodels/datasets/lmdb_database.py +++ b/ocpmodels/datasets/lmdb_database.py @@ -8,9 +8,12 @@ https://gitlab.com/ase/ase/-/blob/master/LICENSE """ +from __future__ import annotations import os +import typing import zlib +from pathlib import Path from typing import Optional import lmdb @@ -18,6 +21,11 @@ import orjson from ase.db.core import Database, now, ops from ase.db.row import AtomsRow +from typing_extensions import Self + +if typing.TYPE_CHECKING: + from ase import Atoms + # These are special keys in the ASE LMDB that hold # metadata and other info @@ -25,12 +33,9 @@ class LMDBDatabase(Database): - def __enter__(self) -> "LMDBDatabase": - return self - def __init__( self, - filename: Optional[str] = None, + filename: Optional[str | Path] = None, create_indices: bool = True, use_lock_file: bool = False, serial: bool = False, @@ -43,7 +48,12 @@ def __init__( arguments, except that we add a readonly flag. """ super().__init__( - filename, create_indices, use_lock_file, serial, *args, **kwargs + Path(filename), + create_indices, + use_lock_file, + serial, + *args, + **kwargs, ) # Add a readonly mode for when we're only training @@ -53,7 +63,7 @@ def __init__( if self.readonly: # Open a new env self.env = lmdb.open( - self.filename, + str(self.filename), subdir=False, meminit=False, map_async=True, @@ -67,7 +77,7 @@ def __init__( else: # Open a new env with write access self.env = lmdb.open( - self.filename, + str(self.filename), map_size=1099511627776 * 2, subdir=False, meminit=False, @@ -77,9 +87,12 @@ def __init__( self.txn = self.env.begin(write=True) # Load all ids based on keys in the DB. + self.ids = [] + self.deleted_ids = [] self._load_ids() - return + def __enter__(self) -> Self: + return self def __exit__(self, exc_type, exc_value, tb) -> None: self.close() @@ -89,7 +102,13 @@ def close(self) -> None: self.txn.commit() self.env.close() - def _write(self, atoms, key_value_pairs, data, id): + def _write( + self, + atoms: Atoms | AtomsRow, + key_value_pairs: dict, + data: Optional[dict], + idx: Optional[int] = None, + ) -> None: Database._write(self, atoms, key_value_pairs, data) mtime = now() @@ -121,40 +140,53 @@ def _write(self, atoms, key_value_pairs, data, id): constraint.todict() for constraint in constraints ] - # json doesn't like Cell objects, so make it a cell + # json doesn't like Cell objects, so make it an array dct["cell"] = np.asarray(dct["cell"]) - if id is None: - nextid = self._get_nextid() - id = nextid - nextid += 1 + if idx is None: + idx = self._nextid + nextid = idx + 1 else: - data = self.txn.get("{id}".encode("ascii")) + data = self.txn.get(f"{idx}".encode("ascii")) assert data is not None - # Add the new entry, then add the id and write the nextid + # Add the new entry self.txn.put( - f"{id}".encode("ascii"), + f"{idx}".encode("ascii"), zlib.compress( orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY) ), ) - self.ids.append(id) - self.txn.put( - "nextid".encode("ascii"), - zlib.compress( - orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY) - ), - ) + # only append if idx is not in ids + if idx not in self.ids: + self.ids.append(idx) + self.txn.put( + "nextid".encode("ascii"), + zlib.compress( + orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY) + ), + ) + # check if id is in removed ids and remove accordingly + if idx in self.deleted_ids: + self.deleted_ids.remove(idx) + self._write_deleted_ids() - return id + return idx - def delete(self, ids) -> None: - for id in ids: - self.txn.delete(f"{id}".encode("ascii")) - self.ids.remove(id) + def _update( + self, + idx: int, + key_value_pairs: Optional[dict] = None, + data: Optional[dict] = None, + ): + # hack this to play nicely with ASE code + row = self._get_row(idx, include_data=True) + if data is not None or key_value_pairs is not None: + self._write( + atoms=row, idx=idx, key_value_pairs=key_value_pairs, data=data + ) - self.deleted_ids += ids + def _write_deleted_ids(self): self.txn.put( "deleted_ids".encode("ascii"), zlib.compress( @@ -164,29 +196,34 @@ def delete(self, ids) -> None: ), ) - def _get_row(self, id, include_data: bool = True): - if id is None: + def delete(self, ids: list[int]) -> None: + for idx in ids: + self.txn.delete(f"{idx}".encode("ascii")) + self.ids.remove(idx) + + self.deleted_ids += ids + self._write_deleted_ids() + + def _get_row(self, idx: int, include_data: bool = True): + if idx is None: assert len(self.ids) == 1 - id = self.ids[0] - data = self.txn.get(f"{id}".encode("ascii")) + idx = self.ids[0] + data = self.txn.get(f"{idx}".encode("ascii")) if data is not None: dct = orjson.loads(zlib.decompress(data)) else: - raise KeyError(f"Id {id} missing from the database!") + raise KeyError(f"Id {idx} missing from the database!") if not include_data: dct.pop("data", None) - dct["id"] = id + dct["id"] = idx return AtomsRow(dct) def _get_row_by_index(self, index: int, include_data: bool = True): - """Auxiliary function to get the ith entry, rather than - a specific id - """ - id = self.ids[index] - data = self.txn.get(f"{id}".encode("ascii")) + """Auxiliary function to get the ith entry, rather than a specific id""" + data = self.txn.get(f"{self.ids[index]}".encode("ascii")) if data is not None: dct = orjson.loads(zlib.decompress(data)) @@ -202,12 +239,12 @@ def _get_row_by_index(self, index: int, include_data: bool = True): def _select( self, keys, - cmps, + cmps: list[tuple[str, str, str]], explain: bool = False, verbosity: int = 0, - limit=None, + limit: Optional[int] = None, offset: int = 0, - sort=None, + sort: Optional[str] = None, include_data: bool = True, columns: str = "all", ): @@ -215,16 +252,13 @@ def _select( yield {"explain": (0, 0, 0, "scan table")} return - if sort: + if sort is not None: if sort[0] == "-": reverse = True sort = sort[1:] else: reverse = False - def f(row): - return row.get(sort, missing) - rows = [] missing = [] for row in self._select(keys, cmps): @@ -248,10 +282,10 @@ def f(row): cmps = [(key, ops[op], val) for key, op, val in cmps] n = 0 - for id in self.ids: + for idx in self.ids: if n - offset == limit: return - row = self._get_row(id, include_data=False) + row = self._get_row(idx, include_data=include_data) for key in keys: if key not in row: @@ -296,16 +330,14 @@ def metadata(self, dct): ), ) - def _get_nextid(self): + @property + def _nextid(self): """Get the id of the next row to be written""" # Get the nextid nextid_data = self.txn.get("nextid".encode("ascii")) - if nextid_data is not None: - nextid = orjson.loads(zlib.decompress(nextid_data)) - else: - # This db is empty; start at 1! - nextid = 1 - + nextid = ( + orjson.loads(zlib.decompress(nextid_data)) if nextid_data else 1 + ) return nextid def count(self, selection=None, **kwargs) -> int: @@ -334,14 +366,10 @@ def _load_ids(self) -> None: # Load the deleted ids deleted_ids_data = self.txn.get("deleted_ids".encode("ascii")) - if deleted_ids_data is None: - self.deleted_ids = [] - else: + if deleted_ids_data is not None: self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data)) # Reconstruct the full id list self.ids = [ - i - for i in range(1, self._get_nextid()) - if i not in set(self.deleted_ids) + i for i in range(1, self._nextid) if i not in set(self.deleted_ids) ] diff --git a/ocpmodels/datasets/lmdb_dataset.py b/ocpmodels/datasets/lmdb_dataset.py index 93e13ed33..1c7e313ac 100644 --- a/ocpmodels/datasets/lmdb_dataset.py +++ b/ocpmodels/datasets/lmdb_dataset.py @@ -21,6 +21,7 @@ from ocpmodels.common.registry import registry from ocpmodels.common.typing import assert_is_instance from ocpmodels.common.utils import pyg2_data_transform +from ocpmodels.datasets._utils import rename_data_object_keys from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata from ocpmodels.modules.transforms import DataTransforms @@ -44,11 +45,9 @@ class LmdbDataset(Dataset[T_co]): folder, but lmdb lengths are now calculated directly from the number of keys. Args: config (dict): Dataset configuration - transform (callable, optional): Data transform function. - (default: :obj:`None`) """ - def __init__(self, config, transform=None) -> None: + def __init__(self, config) -> None: super(LmdbDataset, self).__init__() self.config = config @@ -151,13 +150,9 @@ def __getitem__(self, idx: int) -> T_co: data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) if self.key_mapping is not None: - for _property in self.key_mapping: - # catch for test data not containing labels - if _property in data_object: - new_property = self.key_mapping[_property] - if new_property not in data_object: - data_object[new_property] = data_object[_property] - del data_object[_property] + data_object = rename_data_object_keys( + data_object, self.key_mapping + ) data_object = self.transforms(data_object) diff --git a/ocpmodels/datasets/oc22_lmdb_dataset.py b/ocpmodels/datasets/oc22_lmdb_dataset.py index aee0a2f81..347f3d25d 100644 --- a/ocpmodels/datasets/oc22_lmdb_dataset.py +++ b/ocpmodels/datasets/oc22_lmdb_dataset.py @@ -17,6 +17,7 @@ from ocpmodels.common.registry import registry from ocpmodels.common.typing import assert_is_instance as aii from ocpmodels.common.utils import pyg2_data_transform +from ocpmodels.datasets._utils import rename_data_object_keys from ocpmodels.modules.transforms import DataTransforms @@ -198,12 +199,9 @@ def __getitem__(self, idx): data_object[attr] -= lin_energy if self.key_mapping is not None: - for _property in self.key_mapping: - if _property in data_object: - new_property = self.key_mapping[_property] - if new_property not in data_object: - data_object[new_property] = data_object[_property] - del data_object[_property] + data_object = rename_data_object_keys( + data_object, self.key_mapping + ) # to jointly train on oc22+oc20, need to delete these oc20-only attributes # ensure otf_graph=1 in your model configuration diff --git a/ocpmodels/datasets/target_metadata_guesser.py b/ocpmodels/datasets/target_metadata_guesser.py index 844bd8127..0ee58e80a 100644 --- a/ocpmodels/datasets/target_metadata_guesser.py +++ b/ocpmodels/datasets/target_metadata_guesser.py @@ -1,6 +1,7 @@ import logging import numpy as np +from ase.stress import voigt_6_to_full_3x3_stress def uniform_atoms_lengths(atoms_lens) -> bool: @@ -184,6 +185,16 @@ def guess_property_metadata(atoms_list): np.array(atoms.calc.results[key]) for atoms in atoms_list ] + # stress needs to be handled separately in case it was saved in voigt (6, ) notation + # atoms2graphs will always request voigt=False so turn it into full 3x3 + if key == "stress": + target_samples = [ + voigt_6_to_full_3x3_stress(sample) + if sample.shape != (3, 3) + else sample + for sample in target_samples + ] + # Guess the metadata targets[f"{key}"] = guess_target_metadata( atoms_len, target_samples diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index 34291f173..f88da439b 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -5,7 +5,9 @@ LICENSE file in the root directory of this source tree. """ -from typing import Optional +from __future__ import annotations + +from typing import Optional, Sequence import ase.db.sqlite import ase.io.trajectory @@ -17,8 +19,8 @@ try: from pymatgen.io.ase import AseAtomsAdaptor -except Exception: - pass +except ImportError: + AseAtomsAdaptor = None try: @@ -45,6 +47,7 @@ class AtomsToGraphs: radius (int or float): Cutoff radius in Angstroms to search for neighbors. r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. + r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. r_distances (bool): Return the distances with other properties. Default is False, so the distances will not be returned. r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. @@ -52,12 +55,15 @@ class AtomsToGraphs: Default is True, so the fixed indices will be returned. r_pbc (bool): Return the periodic boundary conditions with other properties. Default is False, so the periodic boundary conditions will not be returned. + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other + properties. Default is None, so no data will be returned as properties. Attributes: max_neigh (int): Maximum number of neighbors to consider. radius (int or float): Cutoff radius in Angstoms to search for neighbors. r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned. r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned. + r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned. r_distances (bool): Return the distances with other properties. Default is False, so the distances will not be returned. r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned. @@ -65,7 +71,8 @@ class AtomsToGraphs: Default is True, so the fixed indices will be returned. r_pbc (bool): Return the periodic boundary conditions with other properties. Default is False, so the periodic boundary conditions will not be returned. - + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other + properties. Default is None, so no data will be returned as properties. """ def __init__( @@ -78,19 +85,28 @@ def __init__( r_edges: bool = True, r_fixed: bool = True, r_pbc: bool = False, + r_stress: bool = False, + r_data_keys: Optional[Sequence[str]] = None, ) -> None: self.max_neigh = max_neigh self.radius = radius self.r_energy = r_energy self.r_forces = r_forces + self.r_stress = r_stress self.r_distances = r_distances self.r_fixed = r_fixed self.r_edges = r_edges self.r_pbc = r_pbc + self.r_data_keys = r_data_keys def _get_neighbors_pymatgen(self, atoms: ase.Atoms): """Preforms nearest neighbor search and returns edge index, distances, and cell offsets""" + if AseAtomsAdaptor is None: + raise RuntimeError( + "Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed." + ) + struct = AseAtomsAdaptor.get_structure(atoms) _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list( r=self.radius, numerical_tol=0, exclude_self=True @@ -129,7 +145,7 @@ def _reshape_features(self, c_index, n_index, n_distance, offsets): return edge_index, edge_distances, cell_offsets def convert(self, atoms: ase.Atoms, sid=None): - """Convert a single atomic stucture to a graph. + """Convert a single atomic structure to a graph. Args: atoms (ase.atoms.Atoms): An ASE atoms object. @@ -177,14 +193,19 @@ def convert(self, atoms: ase.Atoms, sid=None): data.cell_offsets = cell_offsets if self.r_energy: energy = atoms.get_potential_energy(apply_constraint=False) - data.y = energy + data.energy = energy if self.r_forces: forces = torch.Tensor(atoms.get_forces(apply_constraint=False)) - data.force = forces + data.forces = forces + if self.r_stress: + stress = torch.Tensor( + atoms.get_stress(apply_constraint=False, voigt=False) + ) + data.stress = stress if self.r_distances and self.r_edges: data.distances = edge_distances if self.r_fixed: - fixed_idx = torch.zeros(natoms) + fixed_idx = torch.zeros(natoms, dtype=torch.int) if hasattr(atoms, "constraints"): from ase.constraints import FixAtoms @@ -194,6 +215,13 @@ def convert(self, atoms: ase.Atoms, sid=None): data.fixed = fixed_idx if self.r_pbc: data.pbc = torch.tensor(atoms.pbc) + if self.r_data_keys is not None: + for data_key in self.r_data_keys: + data[data_key] = ( + atoms.info[data_key] + if isinstance(atoms.info[data_key], (int, float)) + else torch.Tensor(atoms.info[data_key]) + ) return data diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 86bab1a9d..336628f00 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -411,6 +411,11 @@ def predict( results_file: Optional[str] = None, disable_tqdm: bool = False, ): + if self.is_debug and per_image: + raise FileNotFoundError( + "Predictions require debug mode to be turned off." + ) + ensure_fitted(self._unwrapped_model, warn=True) if distutils.is_master() and not disable_tqdm: @@ -515,10 +520,18 @@ def predict( return predictions ### Get unique system identifiers - sids = batch.sid.tolist() + sids = ( + batch.sid.tolist() + if isinstance(batch.sid, torch.Tensor) + else batch.sid + ) ## Support naming structure for OC20 S2EF if "fid" in batch: - fids = batch.fid.tolist() + fids = ( + batch.fid.tolist() + if isinstance(batch.fid, torch.Tensor) + else batch.fid + ) systemids = [f"{sid}_{fid}" for sid, fid in zip(sids, fids)] else: systemids = [f"{sid}" for sid in sids] @@ -581,7 +594,9 @@ def run_relaxations(self, split="val"): if check_traj_files( batch, self.config["task"]["relax_opt"].get("traj_dir", None) ): - logging.info(f"Skipping batch: {batch.sid.tolist()}") + logging.info( + f"Skipping batch: {batch.sid.tolist() if isinstance(batch.sid, torch.Tensor) else batch.sid}" + ) continue relaxed_batch = ml_relax( @@ -596,7 +611,12 @@ def run_relaxations(self, split="val"): ) if self.config["task"].get("write_pos", False): - systemids = [str(i) for i in relaxed_batch.sid.tolist()] + sid_list = ( + relaxed_batch.sid.tolist() + if isinstance(relaxed_batch.sid, torch.Tensor) + else relaxed_batch.sid + ) + systemids = [str(sid) for sid in sid_list] natoms = relaxed_batch.natoms.tolist() positions = torch.split(relaxed_batch.pos, natoms) batch_relaxed_positions = [pos.tolist() for pos in positions] @@ -679,6 +699,7 @@ def run_relaxations(self, split="val"): # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] + gather_results["pos"] = np.concatenate( np.array(gather_results["pos"])[idx] ) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index d1767c978..34dd47411 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -1,6 +1,7 @@ import os import numpy as np +import pytest from ase import build, db from ase.calculators.singlepoint import SinglePointCalculator from ase.io import Trajectory, write @@ -18,234 +19,131 @@ build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), ] for atoms in structures: - calc = SinglePointCalculator(atoms, energy=1, forces=atoms.positions) + calc = SinglePointCalculator( + atoms, + energy=1, + forces=atoms.positions, + # there is an issue with ASE db when writing a db with 3x3 stress it is flattened to (9,) and then + # errors when trying to read it + stress=np.random.random((6,)), + ) atoms.calc = calc - atoms.info["test_extensive_property"] = 3 * len(atoms) + atoms.info["extensive_property"] = 3 * len(atoms) + atoms.info["tensor_property"] = np.random.random((6, 6)) structures[2].set_pbc(True) -def test_ase_read_dataset() -> None: - for i, structure in enumerate(structures): - write( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), f"{i}.cif" - ), - structure, +@pytest.fixture( + scope="function", + params=[ + "db_dataset", + "db_dataset_folder", + "db_dataset_list", + "db_dataset_path_list", + "lmdb_dataset", + "aselmdb_dataset", + ], +) +def ase_dataset(request, tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("dataset") + mult = 1 + a2g_args = { + "r_energy": True, + "r_forces": True, + "r_stress": True, + "r_data_keys": ["extensive_property", "tensor_property"], + } + if request.param == "db_dataset": + with db.connect(tmp_path / "asedb.db") as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + dataset = AseDBDataset( + config={"src": str(tmp_path / "asedb.db"), "a2g_args": a2g_args} ) - - dataset = AseReadDataset( - config={ - "src": os.path.join(os.path.dirname(os.path.abspath(__file__))), - "pattern": "*.cif", - } - ) - - assert len(dataset) == len(structures) - data = dataset[0] - del data - - dataset.close_db() - - for i in range(len(structures)): - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), f"{i}.cif" - ) + elif ( + request.param == "db_dataset_folder" + or request.param == "db_dataset_list" + ): + for db_name in ("asedb1.db", "asedb2.db"): + with db.connect(tmp_path / db_name) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + mult = 2 + src = ( + str(tmp_path) + if request.param == "db_dataset_folder" + else [str(tmp_path / "asedb1.db"), str(tmp_path / "asedb2.db")] ) - - -def test_ase_db_dataset() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ) + dataset = AseDBDataset(config={"src": src, "a2g_args": a2g_args}) + elif request.param == "db_dataset_path_list": + os.mkdir(tmp_path / "dir1") + os.mkdir(tmp_path / "dir2") + + for dir_name in ("dir1", "dir2"): + for db_name in ("asedb1.db", "asedb2.db"): + with db.connect(tmp_path / dir_name / db_name) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + mult = 4 + dataset = AseDBDataset( + config={ + "src": [str(tmp_path / "dir1"), str(tmp_path / "dir2")], + "a2g_args": a2g_args, + } ) - except FileNotFoundError: - pass - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } - ) - - assert len(dataset) == len(structures) - data = dataset[0] - - del data - - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) - + elif request.param == "lmbd_dataset": + with LMDBDatabase(str(tmp_path / "asedb.lmdb")) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) -def test_ase_db_dataset_folder() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb1.db" - ) + dataset = AseDBDataset( + config={"src": str(tmp_path / "asedb.lmdb"), "a2g_args": a2g_args} ) - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb2.db" - ) - ) - except FileNotFoundError: - pass - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "./" - ), - } - ) + else: # "aselmbd_dataset" with .aselmdb file extension + with LMDBDatabase(str(tmp_path / "asedb.lmdb")) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) - assert len(dataset) == len(structures) * 2 - data = dataset[0] - del data - - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) - - -def test_ase_db_dataset_list() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb1.db" - ) - ) - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb2.db" - ) + dataset = AseDBDataset( + config={"src": str(tmp_path / "asedb.lmdb"), "a2g_args": a2g_args} ) - except FileNotFoundError: - pass - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - dataset = AseDBDataset( - config={ - "src": [ - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb1.db" - ), - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb2.db" - ), - ] - } - ) - assert len(dataset) == len(structures) * 2 - data = dataset[0] - del data + return dataset, mult - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) +def test_ase_dataset(ase_dataset): + dataset, mult = ase_dataset + assert len(dataset) == mult * len(structures) + for data in dataset: + assert hasattr(data, "y") + assert data.forces.shape == (data.natoms, 3) + assert data.stress.shape == (3, 3) + assert data.tensor_property.shape == (6, 6) + assert isinstance(data.extensive_property, int) -def test_ase_lmdb_dataset() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ) - ) - except FileNotFoundError: - pass - with LMDBDatabase( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) +def test_ase_read_dataset(tmp_path) -> None: + # unfortunately there is currently no clean (already implemented) way to save atoms.info when saving + # individual structures - so test separately + for i, structure in enumerate(structures): + write(tmp_path / f"{i}.cif", structure) - dataset = AseDBDataset( + dataset = AseReadDataset( config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ), + "src": str(tmp_path), + "pattern": "*.cif", } ) assert len(dataset) == len(structures) data = dataset[0] del data + dataset.close_db() - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) - - -def test_lmdb_metadata_guesser() -> None: - # Cleanup old lmdb in case it's left over from previous tests - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ) - ) - except FileNotFoundError: - pass - - # Write an LMDB - with LMDBDatabase( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) as database: - for i, structure in enumerate(structures): - database.write(structure, data=structure.info) - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ), - } - ) +def test_ase_metadata_guesser(ase_dataset) -> None: + dataset, _ = ase_dataset metadata = dataset.get_metadata() @@ -259,76 +157,32 @@ def test_lmdb_metadata_guesser() -> None: assert metadata["targets"]["forces"]["extensive"] is True assert metadata["targets"]["forces"]["type"] == "per-atom" - # Confirm forces metadata guessed properly - assert ( - metadata["targets"]["info.test_extensive_property"]["extensive"] - is True - ) - assert metadata["targets"]["info.test_extensive_property"]["shape"] == () - assert ( - metadata["targets"]["info.test_extensive_property"]["type"] - == "per-image" - ) - - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) - - -def test_ase_metadata_guesser() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ) - ) - except FileNotFoundError: - pass + # Confirm stress metadata guessed properly + assert metadata["targets"]["stress"]["shape"] == (3, 3) + assert metadata["targets"]["stress"]["extensive"] is False + assert metadata["targets"]["stress"]["type"] == "per-image" - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure, data=structure.info) - - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } + # Confirm extensive_property metadata guessed properly + assert metadata["targets"]["info.extensive_property"]["extensive"] is True + assert metadata["targets"]["info.extensive_property"]["shape"] == () + assert ( + metadata["targets"]["info.extensive_property"]["type"] == "per-image" ) - metadata = dataset.get_metadata() - - # Confirm energy metadata guessed properly - assert metadata["targets"]["energy"]["extensive"] is False - assert metadata["targets"]["energy"]["shape"] == () - assert metadata["targets"]["energy"]["type"] == "per-image" + # Confirm tensor_property metadata guessed properly + assert metadata["targets"]["info.tensor_property"]["extensive"] is False + assert metadata["targets"]["info.tensor_property"]["shape"] == (6, 6) + assert metadata["targets"]["info.tensor_property"]["type"] == "per-image" - # Confirm forces metadata guessed properly - assert metadata["targets"]["forces"]["shape"] == (3,) - assert metadata["targets"]["forces"]["extensive"] is True - assert metadata["targets"]["forces"]["type"] == "per-atom" - # Confirm forces metadata guessed properly - assert ( - metadata["targets"]["info.test_extensive_property"]["extensive"] - is True - ) - assert metadata["targets"]["info.test_extensive_property"]["shape"] == () - assert ( - metadata["targets"]["info.test_extensive_property"]["type"] - == "per-image" - ) +def test_db_add_delete(tmp_path) -> None: + database = db.connect(tmp_path / "asedb.db") + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } - ) + dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) + assert len(dataset) == len(structures) + orig_len = len(dataset) database.delete([1]) @@ -337,55 +191,20 @@ def test_ase_metadata_guesser() -> None: build.bulk("Al"), ] - for i, structure in enumerate(new_structures): - database.write(structure) - - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } - ) - - assert len(dataset) == len(structures) + len(new_structures) - 1 - data = dataset[:] - assert data - - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) + for i, atoms in enumerate(new_structures): + database.write(atoms, data=atoms.info) + dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) + assert len(dataset) == orig_len + len(new_structures) - 1 dataset.close_db() -def test_ase_multiread_dataset() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test.traj" - ) - ) - except FileNotFoundError: - pass - - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_index_file" - ) - ) - except FileNotFoundError: - pass - +def test_ase_multiread_dataset(tmp_path) -> None: atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)] energies = np.linspace(1, 0, len(atoms_objects)) - traj = Trajectory( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj"), - mode="w", - ) + traj = Trajectory(tmp_path / "test.traj", mode="w") for atoms, energy in zip(atoms_objects, energies): calc = SinglePointCalculator( @@ -396,7 +215,7 @@ def test_ase_multiread_dataset() -> None: dataset = AseReadMultiStructureDataset( config={ - "src": os.path.join(os.path.dirname(os.path.abspath(__file__))), + "src": str(tmp_path), "pattern": "*.traj", "keep_in_memory": True, "atoms_transform_args": { @@ -406,35 +225,19 @@ def test_ase_multiread_dataset() -> None: ) assert len(dataset) == len(atoms_objects) - [dataset[:]] - f = open( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_index_file" - ), - "w", - ) - f.write( - f"{os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test.traj')} {len(atoms_objects)}" - ) - f.close() + with open(tmp_path / "test_index_file", "w") as f: + f.write(f"{tmp_path / 'test.traj'} {len(atoms_objects)}") dataset = AseReadMultiStructureDataset( - config={ - "index_file": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_index_file" - ) - }, + config={"index_file": str(tmp_path / "test_index_file")}, ) assert len(dataset) == len(atoms_objects) - [dataset[:]] dataset = AseReadMultiStructureDataset( config={ - "index_file": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_index_file" - ), + "index_file": str(tmp_path / "test_index_file"), "a2g_args": { "r_energy": True, "r_forces": True, @@ -444,15 +247,14 @@ def test_ase_multiread_dataset() -> None: ) assert len(dataset) == len(atoms_objects) - [dataset[:]] assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].y - assert dataset[-1].y_relaxed == dataset[-1].y + assert dataset[0].y_relaxed != dataset[0].energy + assert dataset[-1].y_relaxed == dataset[-1].energy dataset = AseReadDataset( config={ - "src": os.path.join(os.path.dirname(os.path.abspath(__file__))), + "src": str(tmp_path), "pattern": "*.traj", "ase_read_args": { "index": "0", @@ -465,16 +267,14 @@ def test_ase_multiread_dataset() -> None: } ) - [dataset[:]] - assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].y + assert dataset[0].y_relaxed != dataset[0].energy - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj") - ) - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_index_file" - ) - ) + +def test_empty_dataset(tmp_path): + # raises error on empty dataset + with pytest.raises(ValueError): + AseReadMultiStructureDataset(config={"src": str(tmp_path)}) + + with pytest.raises(ValueError): + AseDBDataset(config={"src": str(tmp_path)}) diff --git a/tests/datasets/test_ase_lmdb.py b/tests/datasets/test_ase_lmdb.py index 29ad95b66..52f0fec64 100644 --- a/tests/datasets/test_ase_lmdb.py +++ b/tests/datasets/test_ase_lmdb.py @@ -1,25 +1,16 @@ -from pathlib import Path - import numpy as np -import tqdm +import pytest from ase import build from ase.calculators.singlepoint import SinglePointCalculator from ase.constraints import FixAtoms +from ase.db.row import AtomsRow from ocpmodels.datasets.lmdb_database import LMDBDatabase -DB_NAME = "ase_lmdb.lmdb" N_WRITES = 100 N_READS = 200 -def cleanup_asedb() -> None: - if Path(DB_NAME).is_file(): - Path(DB_NAME).unlink() - if Path(f"{DB_NAME}-lock").is_file(): - Path(f"{DB_NAME}-lock").unlink() - - test_structures = [ build.molecule("H2O", vacuum=4), build.bulk("Cu"), @@ -61,110 +52,77 @@ def generate_random_structure(): return slab -def write_random_atoms() -> None: - slab = build.fcc111("Cu", size=(4, 4, 3), vacuum=10.0) - with LMDBDatabase(DB_NAME) as db: +@pytest.fixture(scope="function") +def ase_lmbd_path(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("dataset") + with LMDBDatabase(tmp_path / "ase_lmdb.lmdb") as db: for structure in test_structures: db.write(structure) - for i in tqdm.tqdm(range(N_WRITES)): + for _ in range(N_WRITES): slab = generate_random_structure() - # Save the slab info, and make sure the info gets put in as data db.write(slab, data=slab.info) + return tmp_path / "ase_lmdb.lmdb" -def test_aselmdb_write() -> None: - # Representative structure - write_random_atoms() - - with LMDBDatabase(DB_NAME, readonly=True) as db: +def test_aselmdb_write(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: for i, structure in enumerate(test_structures): assert str(structure) == str(db._get_row_by_index(i).toatoms()) - cleanup_asedb() - - -def test_aselmdb_count() -> None: - # Representative structure - write_random_atoms() - with LMDBDatabase(DB_NAME, readonly=True) as db: +def test_aselmdb_count(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: assert db.count() == N_WRITES + len(test_structures) - cleanup_asedb() - - -def test_aselmdb_delete() -> None: - cleanup_asedb() - # Representative structure - write_random_atoms() - - with LMDBDatabase(DB_NAME) as db: +def test_aselmdb_delete(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: for i in range(5): # Note the available ids list is updating # but the ids themselves are fixed. db.delete([db.ids[0]]) - assert db.count() == N_WRITES + len(test_structures) - 5 - cleanup_asedb() - -def test_aselmdb_randomreads() -> None: - write_random_atoms() - - with LMDBDatabase(DB_NAME, readonly=True) as db: - for i in tqdm.tqdm(range(N_READS)): +def test_aselmdb_randomreads(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: + for _ in range(N_READS): total_size = db.count() - row = db._get_row_by_index(np.random.choice(total_size)).toatoms() - del row - cleanup_asedb() - + assert isinstance( + db._get_row_by_index(np.random.choice(total_size)), AtomsRow + ) -def test_aselmdb_constraintread() -> None: - write_random_atoms() - with LMDBDatabase(DB_NAME, readonly=True) as db: +def test_aselmdb_constraintread(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: atoms = db._get_row_by_index(2).toatoms() - assert type(atoms.constraints[0]) == FixAtoms + assert isinstance(atoms.constraints[0], FixAtoms) - cleanup_asedb() - -def update_keyvalue_pair() -> None: - write_random_atoms() - with LMDBDatabase(DB_NAME) as db: +def test_update_keyvalue_pair(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: db.update(1, test=5) - with LMDBDatabase(DB_NAME) as db: - row = db.get_row_by_id(1) + with LMDBDatabase(ase_lmbd_path) as db: + row = db._get_row(1) assert row.test == 5 - cleanup_asedb() - -def update_atoms() -> None: - write_random_atoms() - with LMDBDatabase(DB_NAME) as db: +def test_update_atoms(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: db.update(40, atoms=test_structures[-1]) - with LMDBDatabase(DB_NAME) as db: - row = db.get_row_by_id(40) + with LMDBDatabase(ase_lmbd_path) as db: + row = db._get_row(40) assert str(row.toatoms()) == str(test_structures[-1]) - cleanup_asedb() - -def test_metadata() -> None: - write_random_atoms() - - with LMDBDatabase(DB_NAME) as db: +def test_metadata(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: db.metadata = {"test": True} - with LMDBDatabase(DB_NAME, readonly=True) as db: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: assert db.metadata["test"] is True - - cleanup_asedb() diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py new file mode 100644 index 000000000..d1367c011 --- /dev/null +++ b/tests/datasets/test_utils.py @@ -0,0 +1,18 @@ +import pytest +import torch +from torch_geometric.data import Data + +from ocpmodels.datasets._utils import rename_data_object_keys + + +@pytest.fixture() +def pyg_data(): + return Data(rand_tensor=torch.rand((3, 3))) + + +def test_rename_data_object_keys(pyg_data): + assert "rand_tensor" in pyg_data.keys + key_mapping = {"rand_tensor": "random_tensor"} + pyg_data = rename_data_object_keys(pyg_data, key_mapping) + assert "rand_tensor" not in pyg_data.keys + assert "random_tensor" in pyg_data.keys diff --git a/tests/preprocessing/atoms.json b/tests/preprocessing/atoms.json index 97c6c4730..86d47cf6b 100644 --- a/tests/preprocessing/atoms.json +++ b/tests/preprocessing/atoms.json @@ -1,20 +1,21 @@ {"1": { "calculator": "unknown", "calculator_parameters": {}, - "cell": {"array": {"__ndarray__": [[3, 3], "float64", [0.0, -8.07194878, 0.0, 6.93127032, 0.0, 0.08307657, 0.0, 0.0, 39.37850739]]}, "pbc": {"__ndarray__": [[3], "bool", [true, true, true]]}, "__ase_objtype__": "cell"}, + "cell": {"array": {"__ndarray__": [[3, 3], "float64", [0.0, -8.07194878, 0.0, 6.93127032, 0.0, 0.08307657, 0.0, 0.0, 39.37850739]]}, "__ase_objtype__": "cell"}, "constraints": [{"name": "FixAtoms", "kwargs": {"indices": [2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 22, 23, 24, 26, 27, 28, 30, 31, 33]}}], - "ctime": 20.460198850701047, + "ctime": 24.049558230324397, "energy": -135.66393572, "forces": {"__ndarray__": [[34, 3], "float64", [0.05011766, -0.01973735, 0.23846654, -0.12013861, -0.05240431, -0.22395961, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10578597, 0.01361956, -0.05699137, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03172177, 0.00066391, -0.01049754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00908246, -0.09729627, 0.00726873, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02260358, -0.09508909, -0.01036104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03928853, -0.04423657, 0.04053315, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.02912151, 0.05899768, -0.01100117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.09680946, 0.06950572, 0.05602877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03057741, 0.10594487, -0.04712197, 0.0, 0.0, 0.0]]}, "initial_charges": {"__ndarray__": [[34], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, "initial_magmoms": {"__ndarray__": [[34], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, "momenta": {"__ndarray__": [[34, 3], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, - "mtime": 20.460198850701047, + "mtime": 24.049558230324397, "numbers": {"__ndarray__": [[34], "int64", [6, 8, 13, 13, 13, 13, 13, 13, 13, 13, 29, 29, 29, 29, 29, 29, 29, 29, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34]]}, "pbc": {"__ndarray__": [[3], "bool", [true, true, true]]}, "positions": {"__ndarray__": [[34, 3], "float64", [-0.3289066593614256, -3.0340615866893037, 27.073342845551938, -0.0750331499077992, -2.8712314914365584, 28.205836912191387, 6.2092629718957655, -4.771209055418616, 21.953210855443853, 3.8988395550000003, -0.735234665418617, 18.643976120697392, 1.636610785518665, -1.2302542698255066, 23.72823397486728, 1.5884161381042343, -4.771209055418616, 15.334741779736007, 2.7436278118957658, -6.789196250418616, 21.91167257044385, 0.433204395, -2.7532218604186167, 18.602437835697394, 5.33707967127947, -3.0430981333485136, 25.502246117362063, 5.054051298104235, -6.789196250418616, 15.376280064736006, 3.8988395550000003, -4.771209055418616, 18.643976120697392, 1.5884161381042343, -0.735234665418617, 15.334741779736007, 6.2092629718957655, -0.735234665418617, 21.953210855443853, 1.7024669335227842, -4.898430878701221, 24.462466125364735, 2.7436278118957658, -2.7532218604186167, 21.91167257044385, 0.433204395, -6.789196250418616, 18.602437835697394, 5.0596241087542175, -7.073912126493459, 24.329534869886448, 5.054051298104235, -2.7532218604186167, 15.376280064736006, 1.5841717747237825, -4.763794809025211, 17.789819163977032, 6.205018677828017, -0.7278204190252113, 14.563661393015645, 3.8945952609322516, -0.7278204190252113, 21.09905389955426, 6.2730609484910635, -5.008717107687484, 24.37936591790035, 5.049806934723782, -6.796610416092535, 17.831357448977034, 2.739383517828017, -2.7606360260925347, 14.522123108015645, 0.4289601009322512, -2.7606360260925347, 21.05751561455426, 2.7016609108638554, -7.122213699359126, 24.33216256212159, 5.058295592171984, -2.7458076140252117, 17.84351570962914, 2.747872175276218, -6.781782004025211, 14.534281368667754, 0.43744868906774886, -6.781782004025211, 21.069673874375603, 3.0271987649116516, -2.983072135599385, 24.66107410517354, 3.903083849067749, -4.778623221092535, 21.111212159375604, 1.5926604321719833, -0.7426488310925348, 17.801977424629143, 6.319541839318875, -0.99856463967624, 24.661108015400288, 6.213507335276218, -4.778623221092535, 14.575819653667754]]}, + "stress": {"__ndarray__": [[3, 3], "float64", [-0.02096864, 0.0, 0.0, 0.0, -0.02096864, 0.0, 0.0, 0.0, -0.02096864]]}, "tags": {"__ndarray__": [[34], "int64", [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}, - "unique_id": "77df5102462860280bfa6b622c880125", - "user": "bwood"}, + "unique_id": "e0e1e88b155869bc277be4b4733ed932", + "user": "lbluque"}, "ids": [1], "nextid": 2} diff --git a/tests/preprocessing/test_atoms_to_graphs.py b/tests/preprocessing/test_atoms_to_graphs.py index b0a035a16..58c90bbf1 100644 --- a/tests/preprocessing/test_atoms_to_graphs.py +++ b/tests/preprocessing/test_atoms_to_graphs.py @@ -22,12 +22,25 @@ def atoms_to_graphs_internals(request) -> None: index=0, format="json", ) + atoms.info["stiffness_tensor"] = np.array( + [ + [293, 121, 121, 0, 0, 0], + [121, 293, 121, 0, 0, 0], + [121, 121, 293, 0, 0, 0], + [0, 0, 0, 146, 0, 0], + [0, 0, 0, 0, 146, 0], + [0, 0, 0, 0, 0, 146], + ], + dtype=float, + ) test_object = AtomsToGraphs( max_neigh=200, radius=6, r_energy=True, r_forces=True, + r_stress=True, r_distances=True, + r_data_keys=["stiffness_tensor"], ) request.cls.atg = test_object request.cls.atoms = atoms @@ -100,12 +113,21 @@ def test_convert(self) -> None: np.testing.assert_allclose(act_positions, positions) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) - test_energy = data.y + test_energy = data.energy np.testing.assert_equal(act_energy, test_energy) # forces act_forces = self.atoms.get_forces(apply_constraint=False) - forces = data.force.numpy() + forces = data.forces.numpy() np.testing.assert_allclose(act_forces, forces) + # stress + act_stress = self.atoms.get_stress(apply_constraint=False, voigt=False) + stress = data.stress.numpy() + np.testing.assert_allclose(act_stress, stress) + # additional data (ie stiffness_tensor) + stiffness_tensor = data.stiffness_tensor.numpy() + np.testing.assert_allclose( + self.atoms.info["stiffness_tensor"], stiffness_tensor + ) def test_convert_all(self) -> None: # run convert_all on a list with one atoms object @@ -123,9 +145,18 @@ def test_convert_all(self) -> None: np.testing.assert_allclose(act_positions, positions) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) - test_energy = data_list[0].y + test_energy = data_list[0].energy np.testing.assert_equal(act_energy, test_energy) # forces act_forces = self.atoms.get_forces(apply_constraint=False) - forces = data_list[0].force.numpy() + forces = data_list[0].forces.numpy() np.testing.assert_allclose(act_forces, forces) + # stress + act_stress = self.atoms.get_stress(apply_constraint=False, voigt=False) + stress = data_list[0].stress.numpy() + np.testing.assert_allclose(act_stress, stress) + # additional data (ie stiffness_tensor) + stiffness_tensor = data_list[0].stiffness_tensor.numpy() + np.testing.assert_allclose( + self.atoms.info["stiffness_tensor"], stiffness_tensor + )