Ase dataset updates #622

Merged 71 commits on Apr 1, 2024
Changes from 62 commits
Commits
826598f
minor cleanup of lmbddatabase
lbluque Jan 17, 2024
324a645
ase dataset compat for unified trainer and cleanup
lbluque Jan 18, 2024
6bb3b81
typo in docstring
lbluque Jan 18, 2024
b4614c4
key_mapping docstring
lbluque Jan 19, 2024
d736b00
add stress to atoms_to_graphs.py and test
lbluque Jan 19, 2024
0a17008
allow adding target properties in atoms.info
lbluque Jan 19, 2024
3a7f810
test using generic tensor property in ase_datasets
lbluque Jan 23, 2024
f47a0b8
minor docstring/comments
lbluque Jan 23, 2024
c2a789e
handle stress in voigt notation in metadata guesser
lbluque Jan 23, 2024
47f4578
handle scalar generic values in a2g
lbluque Jan 24, 2024
48dc7d0
clean up ase dataset unit tests
lbluque Jan 24, 2024
8549411
allow .aselmdb extensions
lbluque Jan 25, 2024
3371cae
fix minor bugs in lmdb database and update tests
lbluque Jan 25, 2024
a0a2b2e
make connect_db staticmethod
lbluque Jan 25, 2024
237f000
remove redundant methods and make some private
lbluque Jan 25, 2024
cae0765
allow a list of paths in AseDBdataset
lbluque Jan 26, 2024
dd0b5fc
remove sprinkled print statement
lbluque Jan 26, 2024
303120a
remove deprecated transform kwarg
lbluque Jan 29, 2024
56df36d
fix doctring typo
lbluque Jan 29, 2024
597e421
rename keys function
lbluque Jan 29, 2024
11bd455
fix missing comma in tests
lbluque Jan 29, 2024
07f2172
set default r_edges in a2g in AseDatasets to false
lbluque Jan 29, 2024
d99d383
simple unit-test for good measure
lbluque Jan 29, 2024
18fd2f1
call _get_row directly
lbluque Jan 31, 2024
fd30b43
[wip] allow string sids
lbluque Feb 1, 2024
77a40dd
raise a helpful error if AseAtomsAdaptor not available
lbluque Feb 2, 2024
c441734
remove db extension in filepaths
lbluque Feb 9, 2024
5b13296
set logger to info level when trying to read non db files, remove print
lbluque Feb 16, 2024
242b54f
set logging.debug to avoid saturating logs
lbluque Feb 17, 2024
6c678f1
Update documentation for dataset config changes
emsunshine Feb 26, 2024
fd4d3e8
Update atoms_to_graphs.py
emsunshine Feb 26, 2024
61ffef3
Update test_ase_datasets.py
emsunshine Feb 26, 2024
e3ea559
Update test_ase_datasets.py
emsunshine Feb 26, 2024
21ccf6a
Update test_atoms_to_graphs.py
emsunshine Feb 26, 2024
b8a4c2f
Update test_atoms_to_graphs.py
emsunshine Feb 26, 2024
d0cf20b
Merge branch 'main' into ase_data_updates
lbluque Feb 26, 2024
ec17ce8
case for explicit a2g_args None values
lbluque Feb 27, 2024
8b3cfac
Merge remote-tracking branch 'origin/ase_data_updates' into ase_data_…
lbluque Feb 27, 2024
01863dd
Update update_config()
emsunshine Feb 27, 2024
1c5ca26
Update utils.py
emsunshine Feb 27, 2024
90a6f6e
Update utils.py
emsunshine Feb 27, 2024
885deba
Update ocp_trainer.py
emsunshine Feb 27, 2024
0903f03
Update ocp_trainer.py
emsunshine Feb 27, 2024
17ca6a9
Update ocp_trainer.py
emsunshine Feb 27, 2024
c4ca1b0
Update TRAIN.md
emsunshine Feb 27, 2024
1fdc538
Merge branch 'main' into dataset-config-changes-documentation
emsunshine Feb 27, 2024
ce52b2f
fix concatenating predictions
lbluque Feb 27, 2024
5741907
check if keys exist in atoms.info
lbluque Feb 27, 2024
7f7c0b4
Merge branch 'ase_data_updates' into dataset-config-changes-documenta…
emsunshine Feb 28, 2024
068b053
Update test_ase_datasets.py
emsunshine Feb 28, 2024
987ba9f
use list() to cast all batch.sid/fid
lbluque Mar 5, 2024
3b4ad43
Merge pull request #630 from Open-Catalyst-Project/dataset-config-cha…
lbluque Mar 5, 2024
7995b5e
correctly stack predictions
lbluque Mar 6, 2024
3b6e2f9
Merge branch 'main' into ase_data_updates
lbluque Mar 12, 2024
f0982bb
raise error on empty datasets
lbluque Mar 19, 2024
56531d7
raise ValueError instead of exception
lbluque Mar 19, 2024
b9e758d
code cleanup
lbluque Mar 19, 2024
f6bb5d5
rename get_atoms object -> get_atoms for brevity
lbluque Mar 19, 2024
cdc509a
merge upstream
lbluque Mar 22, 2024
2f6ac22
revert to raise keyerror when data_keys are missing
lbluque Mar 22, 2024
b426842
cast tensors to list using tolist and vstack relaxation pos
lbluque Mar 22, 2024
0709e46
remove r_energy, r_forces, r_stress and r_data_keys from test_dataset…
lbluque Mar 22, 2024
310468d
fix test_dataset key
lbluque Mar 23, 2024
2422bb9
fix test_dataset key!
lbluque Mar 23, 2024
3f2f4bb
revert to not setting a2g_args dataset keys
lbluque Mar 26, 2024
ac3c1c3
fix debug predict logic
mshuaibii Mar 26, 2024
a4087a7
support numpy 1.26
mshuaibii Mar 28, 2024
07ea92f
fix numpy version
mshuaibii Mar 28, 2024
47f47e2
revert write_pos
mshuaibii Mar 28, 2024
ca9dbaf
no list casting on batch lists
lbluque Mar 28, 2024
bdbba48
pretty logging
lbluque Mar 29, 2024
26 changes: 10 additions & 16 deletions TRAIN.md
@@ -204,11 +204,11 @@ To train and validate an OC20 IS2RE/S2EF model on total energies instead of adso

 ```yaml
 task:
-  dataset: oc22_lmdb
   prediction_dtype: float32
   ...

 dataset:
+  format: oc22_lmdb
   train:
     src: data/oc20/s2ef/train
     normalize_labels: False
@@ -308,8 +308,8 @@ For the IS2RE-Total task, the model takes the initial structure as input and pre
 ```yaml
 trainer: energy # Use the EnergyTrainer

-task:
-  dataset: oc22_lmdb # Use the OC22LmdbDataset
+dataset:
+  format: oc22_lmdb # Use the OC22LmdbDataset
   ...
 ```
 You can find example configuration files in [`configs/oc22/is2re`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/is2re).
@@ -321,8 +321,8 @@ The S2EF-Total task takes a structure and predicts the total DFT energy and per-
 ```yaml
 trainer: forces # Use the ForcesTrainer

-task:
-  dataset: oc22_lmdb # Use the OC22LmdbDataset
+dataset:
+  format: oc22_lmdb # Use the OC22LmdbDataset
   ...
 ```
 You can find example configuration files in [`configs/oc22/s2ef`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/s2ef).
@@ -332,8 +332,8 @@ You can find example configuration files in [`configs/oc22/s2ef`](https://githu
 Training on OC20 total energies, whether independently or jointly with OC22, requires a path to the `oc20_ref` (download link provided below) to be specified in the configuration file. These are necessary to convert OC20 adsorption energies into their corresponding total energies. The following changes to the configuration file reflect this requirement:

 ```yaml
-task:
-  dataset: oc22_lmdb
+dataset:
+  format: oc22_lmdb
   ...

 dataset:
@@ -382,10 +382,8 @@ If your data is already in an [ASE Database](https://databases.fysik.dtu.dk/ase/
 To use this dataset, we will just have to change our config files to use the ASE DB Dataset rather than the LMDB Dataset:

 ```yaml
-task:
-  dataset: ase_db
-
 dataset:
+  format: ase_db
   train:
     src: # The path/address to your ASE DB
     connect_args:
@@ -420,10 +418,8 @@ It is possible to train/predict directly on ASE-readable files. This is only rec
 This dataset assumes a single structure will be obtained from each file:

 ```yaml
-task:
-  dataset: ase_read
-
 dataset:
+  format: ase_read
   train:
     src: # The folder that contains ASE-readable files
     pattern: # Pattern matching each file you want to read (e.g. "*/POSCAR"). Search recursively with two wildcards: "**/*.cif".
@@ -443,10 +439,8 @@ dataset:
 This dataset supports reading files that each contain multiple structures (for example, an ASE .traj file). Using an index file, which tells the dataset how many structures each file contains, is recommended. Otherwise, the dataset is forced to load every file at startup and count the number of structures!

 ```yaml
-task:
-  dataset: ase_read_multi
-
 dataset:
+  format: ase_read_multi
   train:
     index_file: Filepath to an index file which contains each filename and the number of structures in each file. e.g.:
       /path/to/relaxation1.traj 200
18 changes: 16 additions & 2 deletions ocpmodels/common/utils.py
@@ -969,7 +969,12 @@ def check_traj_files(batch, traj_dir) -> bool:
     if traj_dir is None:
         return False
     traj_dir = Path(traj_dir)
-    traj_files = [traj_dir / f"{id}.traj" for id in batch.sid.tolist()]
+    sid_list = (
+        batch.sid.tolist()
+        if isinstance(batch.sid, torch.Tensor)
+        else list(batch.sid)
+    )
+    traj_files = [traj_dir / f"{sid}.traj" for sid in sid_list]
     return all(fl.exists() for fl in traj_files)
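
The change above lets `batch.sid` be either a tensor of integer ids or a plain sequence, such as a list of string ids. A minimal sketch of the same idea without a torch dependency follows. The helper names `sid_list_from` and `traj_files_exist` are hypothetical, and the duck-typed `hasattr(sid, "tolist")` check is an assumption standing in for the `isinstance(batch.sid, torch.Tensor)` test used in the PR.

```python
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace


def sid_list_from(sid) -> list:
    """Return system ids as a plain Python list (hypothetical helper)."""
    # Tensor-like containers (e.g. torch.Tensor, numpy.ndarray) expose tolist();
    # anything else is assumed to be an iterable of ids (str or int).
    return sid.tolist() if hasattr(sid, "tolist") else list(sid)


def traj_files_exist(batch, traj_dir) -> bool:
    """Check that a <sid>.traj file exists in traj_dir for every id in batch.sid."""
    if traj_dir is None:
        return False
    traj_dir = Path(traj_dir)
    return all(
        (traj_dir / f"{sid}.traj").exists() for sid in sid_list_from(batch.sid)
    )


# Usage with a stand-in batch holding string ids (no torch required).
batch = SimpleNamespace(sid=["slab_001", "slab_002"])
print(traj_files_exist(batch, None))  # traj_dir is None, so this prints False
```

Duck typing keeps the sketch runnable anywhere. The explicit `isinstance` check in the PR is stricter and only treats real tensors as tensor-like.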


@@ -1204,13 +1209,22 @@ def update_config(base_config):
     are now. Update old configs to fit the new expected structure.
     """
     config = copy.deepcopy(base_config)
-    config["dataset"]["format"] = config["task"].get("dataset", "lmdb")
+
+    # If config["dataset"]["format"] is missing, get it from the task (legacy location).
+    # If it is not there either, default to LMDB.
+    config["dataset"]["format"] = config["dataset"].get(
+        "format", config["task"].get("dataset", "lmdb")
+    )
+
     ### Read task based off config structure, similar to OCPCalculator.
     if config["task"]["dataset"] in [
         "trajectory_lmdb",
         "lmdb",
         "trajectory_lmdb_v2",
         "oc22_lmdb",
+        "ase_read",
+        "ase_read_multi",
+        "ase_db",
     ]:
         task = "s2ef"
     elif config["task"]["dataset"] == "single_point_lmdb":
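
The lookup order introduced in `update_config` is: `dataset.format` first, then the legacy `task.dataset`, then `"lmdb"`. A small standalone sketch of that order is below. The function names `resolve_dataset_format` and `with_dataset_format` are hypothetical and not part of `ocpmodels`. The sketch also assumes `dataset` and `task` are dictionaries, and it tolerates either section being absent, which the hunk above does not show.

```python
from __future__ import annotations

import copy


def resolve_dataset_format(config: dict) -> str:
    """Pick the dataset format: dataset.format, then legacy task.dataset, then "lmdb"."""
    dataset_cfg = config.get("dataset") or {}
    task_cfg = config.get("task") or {}
    return dataset_cfg.get("format", task_cfg.get("dataset", "lmdb"))


def with_dataset_format(base_config: dict) -> dict:
    """Return a deep copy of base_config with dataset.format filled in."""
    config = copy.deepcopy(base_config)
    fmt = resolve_dataset_format(config)
    if not config.get("dataset"):
        config["dataset"] = {}
    config["dataset"]["format"] = fmt
    return config


# A legacy-style config that still names the dataset under "task".
legacy = {
    "task": {"dataset": "oc22_lmdb"},
    "dataset": {"train": {"src": "data/oc20/s2ef/train"}},
}
print(with_dataset_format(legacy)["dataset"]["format"])  # prints oc22_lmdb
```

In the hunk, the later branch still indexes `config["task"]["dataset"]` directly, so a config that only sets `dataset.format` would need that key to be present or the lookup to use `.get` as the sketch does.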
33 changes: 33 additions & 0 deletions ocpmodels/datasets/_utils.py
@@ -0,0 +1,33 @@
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+"""
+
+from __future__ import annotations
+
+import typing
+
+if typing.TYPE_CHECKING:
+    from torch_geometric.data import Data
+
+
+def rename_data_object_keys(
+    data_object: Data, key_mapping: dict[str, str]
+) -> Data:
+    """Rename data object keys
+
+    Args:
+        data_object: data object
+        key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key}
+    """
+    for _property in key_mapping:
+        # catch for test data not containing labels
+        if _property in data_object:
+            new_property = key_mapping[_property]
+            if new_property not in data_object:
+                data_object[new_property] = data_object[_property]
+                del data_object[_property]
+
+    return data_object
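
To illustrate the renaming behaviour, here is a minimal sketch that uses a plain `dict` as a stand-in for a `torch_geometric` `Data` object. It relies only on the `in`, item assignment and `del` operations that the function above uses. The name `rename_keys` and the sample keys are hypothetical. The flattened diff does not show how deeply the `del` is nested, so the sketch assumes that an existing destination key leaves both entries untouched.

```python
from __future__ import annotations

from collections.abc import MutableMapping


def rename_keys(
    data_object: MutableMapping, key_mapping: dict[str, str]
) -> MutableMapping:
    """Rename keys in place following {prev_key: new_key}; absent keys are skipped."""
    for prev_key, new_key in key_mapping.items():
        # Test data may not contain labels, so skip keys that are absent.
        # Assumption: if new_key already exists, nothing is changed.
        if prev_key in data_object and new_key not in data_object:
            data_object[new_key] = data_object[prev_key]
            del data_object[prev_key]
    return data_object


sample = {"y": -1.25, "force": [[0.0, 0.0, 0.1]], "natoms": 1}
renamed = rename_keys(
    sample, {"y": "energy", "force": "forces", "stress": "stress_tensor"}
)
print(sorted(renamed))  # prints ['energy', 'forces', 'natoms']
```

The `stress` entry in the mapping is ignored because the sample has no such key, which mirrors the "test data not containing labels" case mentioned in the code comment.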