Ase dataset updates (#622)
* minor cleanup of lmdb database

* ase dataset compat for unified trainer and cleanup

* typo in docstring

* key_mapping docstring

* add stress to atoms_to_graphs.py and test

* allow adding target properties in atoms.info

* test using generic tensor property in ase_datasets

* minor docstring/comments

* handle stress in voigt notation in metadata guesser

* handle scalar generic values in a2g

* clean up ase dataset unit tests

* allow .aselmdb extensions

* fix minor bugs in lmdb database and update tests

* make connect_db staticmethod

* remove redundant methods and make some private

* allow a list of paths in AseDBdataset

* remove sprinkled print statement

* remove deprecated transform kwarg

* fix docstring typo

* rename keys function

* fix missing comma in tests

* set default r_edges in a2g in AseDatasets to false

* simple unit-test for good measure

* call _get_row directly

* [wip] allow string sids

* raise a helpful error if AseAtomsAdaptor not available

* remove db extension in filepaths

* set logger to info level when trying to read non-db files, remove print

* set logging.debug to avoid saturating logs

* Update documentation for dataset config changes

This PR is intended to address #629

* Update atoms_to_graphs.py

* Update test_ase_datasets.py

* Update test_ase_datasets.py

* Update test_atoms_to_graphs.py

* Update test_atoms_to_graphs.py

* case for explicit a2g_args None values

* Update update_config()

* Update utils.py

* Update utils.py

* Update ocp_trainer.py

More helpful warning for debug mode

* Update ocp_trainer.py

* Update ocp_trainer.py

* Update TRAIN.md

* fix concatenating predictions

* check if keys exist in atoms.info

* Update test_ase_datasets.py

* use list() to cast all batch.sid/fid

* correctly stack predictions

* raise error on empty datasets

* raise ValueError instead of exception

* code cleanup

* rename get_atoms_object -> get_atoms for brevity

* revert to raise keyerror when data_keys are missing

* cast tensors to list using tolist and vstack relaxation pos

* remove r_energy, r_forces, r_stress and r_data_keys from test_dataset with use_train_settings

* fix test_dataset key

* fix test_dataset key!

* revert to not setting a2g_args dataset keys

* fix debug predict logic

* support numpy 1.26

* fix numpy version

* revert write_pos

* no list casting on batch lists

* pretty logging

---------

Co-authored-by: Ethan Sunshine <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
3 people authored Apr 1, 2024
1 parent fa39a8f commit f6e46b1
Showing 16 changed files with 604 additions and 644 deletions.
26 changes: 10 additions & 16 deletions TRAIN.md
````diff
@@ -204,11 +204,11 @@ To train and validate an OC20 IS2RE/S2EF model on total energies instead of adsorption energies:
 ```yaml
 task:
-  dataset: oc22_lmdb
   prediction_dtype: float32
   ...
 dataset:
+  format: oc22_lmdb
   train:
     src: data/oc20/s2ef/train
     normalize_labels: False
````
````diff
@@ -308,8 +308,8 @@ For the IS2RE-Total task, the model takes the initial structure as input and predicts the total DFT energy:
 ```yaml
 trainer: energy # Use the EnergyTrainer
-task:
-  dataset: oc22_lmdb # Use the OC22LmdbDataset
+dataset:
+  format: oc22_lmdb # Use the OC22LmdbDataset
 ...
 ```
 You can find example configuration files in [`configs/oc22/is2re`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/is2re).
````
````diff
@@ -321,8 +321,8 @@ The S2EF-Total task takes a structure and predicts the total DFT energy and per-atom forces:
 ```yaml
 trainer: forces # Use the ForcesTrainer
-task:
-  dataset: oc22_lmdb # Use the OC22LmdbDataset
+dataset:
+  format: oc22_lmdb # Use the OC22LmdbDataset
 ...
 ```
 You can find example configuration files in [`configs/oc22/s2ef`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/s2ef).
````
````diff
@@ -332,8 +332,8 @@ You can find example configuration files in [`configs/oc22/s2ef`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/s2ef).
 Training on OC20 total energies whether independently or jointly with OC22 requires a path to the `oc20_ref` (download link provided below) to be specified in the configuration file. These are necessary to convert OC20 adsorption energies into their corresponding total energies. The following changes in the configuration file capture these changes:

 ```yaml
-task:
-  dataset: oc22_lmdb
+dataset:
+  format: oc22_lmdb
   ...
 dataset:
````
````diff
@@ -382,10 +382,8 @@ If your data is already in an [ASE Database](https://databases.fysik.dtu.dk/ase/)
 To use this dataset, we will just have to change our config files to use the ASE DB Dataset rather than the LMDB Dataset:

 ```yaml
-task:
-  dataset: ase_db
 dataset:
+  format: ase_db
   train:
     src: # The path/address to your ASE DB
     connect_args:
````
````diff
@@ -420,10 +418,8 @@ It is possible to train/predict directly on ASE-readable files. This is only recommended for smaller datasets.
 This dataset assumes a single structure will be obtained from each file:

 ```yaml
-task:
-  dataset: ase_read
 dataset:
+  format: ase_read
   train:
     src: # The folder that contains ASE-readable files
     pattern: # Pattern matching each file you want to read (e.g. "*/POSCAR"). Search recursively with two wildcards: "**/*.cif".
````
````diff
@@ -443,10 +439,8 @@ dataset:
 This dataset supports reading files that each contain multiple structures (for example, an ASE .traj file). Using an index file, which tells the dataset how many structures each file contains, is recommended. Otherwise, the dataset is forced to load every file at startup and count the number of structures!

 ```yaml
-task:
-  dataset: ase_read_multi
 dataset:
+  format: ase_read_multi
   train:
     index_file: # Filepath to an index file which contains each filename and the number of structures in each file, e.g.:
       # /path/to/relaxation1.traj 200
````
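The index file above is just whitespace-separated `<filepath> <n_structures>` pairs, one per line. A minimal sketch of generating one — the helper name and the pluggable counter are illustrative, not part of the repo; with ASE installed the counter could be `lambda p: len(ase.io.read(p, index=":"))`:

```python
import tempfile
from pathlib import Path
from typing import Callable


def write_index_file(
    traj_dir: str,
    index_path: str,
    count_structures: Callable[[Path], int],
    pattern: str = "*.traj",
) -> list:
    """Write one "<filepath> <n_structures>" line per matching file."""
    lines = [
        f"{path} {count_structures(path)}"
        for path in sorted(Path(traj_dir).glob(pattern))
    ]
    Path(index_path).write_text("\n".join(lines) + "\n")
    return lines


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        (Path(d) / "relaxation1.traj").touch()
        (Path(d) / "relaxation2.traj").touch()
        # stand-in counter; real code would parse each trajectory
        print(write_index_file(d, f"{d}/index.txt", lambda p: 200))
```

Writing the index once up front avoids the startup cost of opening every trajectory file.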
1 change: 1 addition & 0 deletions env.common.yml
```diff
@@ -7,6 +7,7 @@ dependencies:
   - ase=3.22.1
   - black=22.3.0
   - e3nn=0.4.4
+  - numpy=1.23.5
   - matplotlib
   - numba
   - orjson
```
18 changes: 16 additions & 2 deletions ocpmodels/common/utils.py
```diff
@@ -969,7 +969,12 @@ def check_traj_files(batch, traj_dir) -> bool:
     if traj_dir is None:
         return False
     traj_dir = Path(traj_dir)
-    traj_files = [traj_dir / f"{id}.traj" for id in batch.sid.tolist()]
+    sid_list = (
+        batch.sid.tolist()
+        if isinstance(batch.sid, torch.Tensor)
+        else batch.sid
+    )
+    traj_files = [traj_dir / f"{sid}.traj" for sid in sid_list]
     return all(fl.exists() for fl in traj_files)
```
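The new branch is easy to exercise outside the trainer. A torch-free sketch — it swaps the `isinstance(..., torch.Tensor)` check for duck-typed `hasattr(..., "tolist")` so the example runs without torch, and the stub batch is hypothetical:

```python
import tempfile
from pathlib import Path
from types import SimpleNamespace


def check_traj_files(batch, traj_dir) -> bool:
    """True iff every sid in the batch already has a <sid>.traj file."""
    if traj_dir is None:
        return False
    traj_dir = Path(traj_dir)
    # batch.sid may be a tensor of ints or, after this change, a list of strings
    sid_list = batch.sid.tolist() if hasattr(batch.sid, "tolist") else batch.sid
    traj_files = [traj_dir / f"{sid}.traj" for sid in sid_list]
    return all(fl.exists() for fl in traj_files)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        batch = SimpleNamespace(sid=["1234", "sid-abc"])  # string sids now allowed
        print(check_traj_files(batch, d))  # False: no .traj files yet
        for sid in batch.sid:
            (Path(d) / f"{sid}.traj").touch()
        print(check_traj_files(batch, d))  # True
```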


```diff
@@ -1204,13 +1209,22 @@ def update_config(base_config):
     are now. Update old configs to fit the new expected structure.
     """
     config = copy.deepcopy(base_config)
-    config["dataset"]["format"] = config["task"].get("dataset", "lmdb")
+
+    # If config["dataset"]["format"] is missing, get it from the task (legacy location).
+    # If it is not there either, default to LMDB.
+    config["dataset"]["format"] = config["dataset"].get(
+        "format", config["task"].get("dataset", "lmdb")
+    )

     ### Read task based off config structure, similar to OCPCalculator.
     if config["task"]["dataset"] in [
         "trajectory_lmdb",
         "lmdb",
         "trajectory_lmdb_v2",
         "oc22_lmdb",
+        "ase_read",
+        "ase_read_multi",
+        "ase_db",
     ]:
         task = "s2ef"
     elif config["task"]["dataset"] == "single_point_lmdb":
```
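The fallback chain above — explicit `dataset.format`, then the legacy `task.dataset` location, then `"lmdb"` — can be checked in isolation. A standalone sketch mirroring just that snippet; the helper name is illustrative, the real logic lives in `update_config` in `ocpmodels/common/utils.py`:

```python
import copy


def resolve_dataset_format(base_config: dict) -> dict:
    """Prefer dataset.format; fall back to legacy task.dataset; default to lmdb."""
    config = copy.deepcopy(base_config)  # never mutate the caller's config
    config["dataset"]["format"] = config["dataset"].get(
        "format", config["task"].get("dataset", "lmdb")
    )
    return config


legacy = {"task": {"dataset": "ase_db"}, "dataset": {"train": {"src": "data.db"}}}
print(resolve_dataset_format(legacy)["dataset"]["format"])  # ase_db
```

Because `dict.get` only evaluates the fallback when the key is absent, a config that already uses the new `dataset.format` layout passes through unchanged.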
33 changes: 33 additions & 0 deletions ocpmodels/datasets/_utils.py
@@ -0,0 +1,33 @@

```python
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations

import typing

if typing.TYPE_CHECKING:
    from torch_geometric.data import Data


def rename_data_object_keys(
    data_object: Data, key_mapping: dict[str, str]
) -> Data:
    """Rename data object keys

    Args:
        data_object: data object
        key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key}
    """
    for _property in key_mapping:
        # catch for test data not containing labels
        if _property in data_object:
            new_property = key_mapping[_property]
            if new_property not in data_object:
                data_object[new_property] = data_object[_property]
                del data_object[_property]

    return data_object
```
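Since `rename_data_object_keys` relies only on `in`, item access, and `del`, its behavior can be illustrated with a plain `dict` standing in for a `torch_geometric` `Data` object (the sample keys below are made up):

```python
def rename_data_object_keys(data_object, key_mapping):
    # same logic as ocpmodels/datasets/_utils.py, shown on a plain dict
    for _property in key_mapping:
        # catch for test data not containing labels
        if _property in data_object:
            new_property = key_mapping[_property]
            if new_property not in data_object:
                data_object[new_property] = data_object[_property]
                del data_object[_property]
    return data_object


sample = {"y_relaxed": -1.23, "force": [[0.0, 0.1, 0.2]]}
renamed = rename_data_object_keys(sample, {"y_relaxed": "energy", "force": "forces"})
print(sorted(renamed))  # ['energy', 'forces']
```

Note the two guards: keys absent from the data object are skipped (so the same mapping works on unlabeled test data), and if the target key already exists the source key is left in place rather than overwritten.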