Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ase dataset updates #622

Merged
merged 71 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
71 commits
Select commit Hold shift + click to select a range
826598f
minor cleanup of lmbddatabase
lbluque Jan 17, 2024
324a645
ase dataset compat for unified trainer and cleanup
lbluque Jan 18, 2024
6bb3b81
typo in docstring
lbluque Jan 18, 2024
b4614c4
key_mapping docstring
lbluque Jan 19, 2024
d736b00
add stress to atoms_to_graphs.py and test
lbluque Jan 19, 2024
0a17008
allow adding target properties in atoms.info
lbluque Jan 19, 2024
3a7f810
test using generic tensor property in ase_datasets
lbluque Jan 23, 2024
f47a0b8
minor docstring/comments
lbluque Jan 23, 2024
c2a789e
handle stress in voigt notation in metadata guesser
lbluque Jan 23, 2024
47f4578
handle scalar generic values in a2g
lbluque Jan 24, 2024
48dc7d0
clean up ase dataset unit tests
lbluque Jan 24, 2024
8549411
allow .aselmdb extensions
lbluque Jan 25, 2024
3371cae
fix minor bugs in lmdb database and update tests
lbluque Jan 25, 2024
a0a2b2e
make connect_db staticmethod
lbluque Jan 25, 2024
237f000
remove redundant methods and make some private
lbluque Jan 25, 2024
cae0765
allow a list of paths in AseDBdataset
lbluque Jan 26, 2024
dd0b5fc
remove sprinkled print statement
lbluque Jan 26, 2024
303120a
remove deprecated transform kwarg
lbluque Jan 29, 2024
56df36d
fix doctring typo
lbluque Jan 29, 2024
597e421
rename keys function
lbluque Jan 29, 2024
11bd455
fix missing comma in tests
lbluque Jan 29, 2024
07f2172
set default r_edges in a2g in AseDatasets to false
lbluque Jan 29, 2024
d99d383
simple unit-test for good measure
lbluque Jan 29, 2024
18fd2f1
call _get_row directly
lbluque Jan 31, 2024
fd30b43
[wip] allow string sids
lbluque Feb 1, 2024
77a40dd
raise a helpful error if AseAtomsAdaptor not available
lbluque Feb 2, 2024
c441734
remove db extension in filepaths
lbluque Feb 9, 2024
5b13296
set logger to info level when trying to read non db files, remove print
lbluque Feb 16, 2024
242b54f
set logging.debug to avoid saturating logs
lbluque Feb 17, 2024
6c678f1
Update documentation for dataset config changes
emsunshine Feb 26, 2024
fd4d3e8
Update atoms_to_graphs.py
emsunshine Feb 26, 2024
61ffef3
Update test_ase_datasets.py
emsunshine Feb 26, 2024
e3ea559
Update test_ase_datasets.py
emsunshine Feb 26, 2024
21ccf6a
Update test_atoms_to_graphs.py
emsunshine Feb 26, 2024
b8a4c2f
Update test_atoms_to_graphs.py
emsunshine Feb 26, 2024
d0cf20b
Merge branch 'main' into ase_data_updates
lbluque Feb 26, 2024
ec17ce8
case for explicit a2g_args None values
lbluque Feb 27, 2024
8b3cfac
Merge remote-tracking branch 'origin/ase_data_updates' into ase_data_…
lbluque Feb 27, 2024
01863dd
Update update_config()
emsunshine Feb 27, 2024
1c5ca26
Update utils.py
emsunshine Feb 27, 2024
90a6f6e
Update utils.py
emsunshine Feb 27, 2024
885deba
Update ocp_trainer.py
emsunshine Feb 27, 2024
0903f03
Update ocp_trainer.py
emsunshine Feb 27, 2024
17ca6a9
Update ocp_trainer.py
emsunshine Feb 27, 2024
c4ca1b0
Update TRAIN.md
emsunshine Feb 27, 2024
1fdc538
Merge branch 'main' into dataset-config-changes-documentation
emsunshine Feb 27, 2024
ce52b2f
fix concatenating predictions
lbluque Feb 27, 2024
5741907
check if keys exist in atoms.info
lbluque Feb 27, 2024
7f7c0b4
Merge branch 'ase_data_updates' into dataset-config-changes-documenta…
emsunshine Feb 28, 2024
068b053
Update test_ase_datasets.py
emsunshine Feb 28, 2024
987ba9f
use list() to cast all batch.sid/fid
lbluque Mar 5, 2024
3b4ad43
Merge pull request #630 from Open-Catalyst-Project/dataset-config-cha…
lbluque Mar 5, 2024
7995b5e
correctly stack predictions
lbluque Mar 6, 2024
3b6e2f9
Merge branch 'main' into ase_data_updates
lbluque Mar 12, 2024
f0982bb
raise error on empty datasets
lbluque Mar 19, 2024
56531d7
raise ValueError instead of exception
lbluque Mar 19, 2024
b9e758d
code cleanup
lbluque Mar 19, 2024
f6bb5d5
rename get_atoms object -> get_atoms for brevity
lbluque Mar 19, 2024
cdc509a
merge upstream
lbluque Mar 22, 2024
2f6ac22
revert to raise keyerror when data_keys are missing
lbluque Mar 22, 2024
b426842
cast tensors to list using tolist and vstack relaxation pos
lbluque Mar 22, 2024
0709e46
remove r_energy, r_forces, r_stress and r_data_keys from test_dataset…
lbluque Mar 22, 2024
310468d
fix test_dataset key
lbluque Mar 23, 2024
2422bb9
fix test_dataset key!
lbluque Mar 23, 2024
3f2f4bb
revert to not setting a2g_args dataset keys
lbluque Mar 26, 2024
ac3c1c3
fix debug predict logic
mshuaibii Mar 26, 2024
a4087a7
support numpy 1.26
mshuaibii Mar 28, 2024
07ea92f
fix numpy version
mshuaibii Mar 28, 2024
47f47e2
revert write_pos
mshuaibii Mar 28, 2024
ca9dbaf
no list casting on batch lists
lbluque Mar 28, 2024
bdbba48
pretty logging
lbluque Mar 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 82 additions & 65 deletions ocpmodels/datasets/ase_datasets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

import bisect
import copy
import functools
import glob
import logging
import os
import warnings
from abc import ABC, abstractmethod
from functools import cache, reduce
from glob import glob
from pathlib import Path
from typing import List
from typing import Any, Callable, Optional

import ase
import numpy as np
Expand All @@ -18,6 +20,7 @@
from ocpmodels.common.registry import registry
from ocpmodels.datasets.lmdb_database import LMDBDatabase
from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata
from ocpmodels.modules.transforms import DataTransforms
from ocpmodels.preprocessing import AtomsToGraphs


Expand Down Expand Up @@ -65,33 +68,38 @@ class AseAtomsDataset(Dataset, ABC):
"""

def __init__(
self, config, transform=None, atoms_transform=apply_one_tags
self,
config: dict,
atoms_transform: Callable[
[ase.Atoms, Any, ...], ase.Atoms
] = apply_one_tags,
transform=None, # is this deprecated?
lbluque marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
self.config = config

a2g_args = config.get("a2g_args", {})
if a2g_args is None:
a2g_args = {}
lbluque marked this conversation as resolved.
Show resolved Hide resolved

# Make sure we always include PBC info in the resulting atoms objects
a2g_args["r_pbc"] = True
self.a2g = AtomsToGraphs(**a2g_args)

self.transform = transform
self.key_mapping = self.config.get("key_mapping", None)
self.transforms = DataTransforms(self.config.get("transforms", {}))

self.atoms_transform = atoms_transform

if self.config.get("keep_in_memory", False):
self.__getitem__ = functools.cache(self.__getitem__)
self.__getitem__ = cache(self.__getitem__)

self.ids = self.load_dataset_get_ids(config)
self.ids = self._load_dataset_get_ids(config)

def __len__(self) -> int:
return len(self.ids)

def __getitem__(self, idx):
# Handle slicing
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self.ids)))]
return [self[i] for i in range(*idx.indices(len(self)))]

# Get atoms object via derived class method
atoms = self.get_atoms_object(self.ids[idx])
Expand All @@ -105,10 +113,10 @@ def __getitem__(self, idx):
sid = atoms.info.get("sid", self.ids[idx])
try:
sid = tensor([sid])
except (RuntimeError, ValueError, TypeError):
warnings.warn(
"Supplied sid is not numeric (or missing). Using dataset indices instead."
)
except:
sid = tensor([idx])

fid = atoms.info.get("fid", tensor([0]))
Expand All @@ -118,11 +126,17 @@ def __getitem__(self, idx):
data_object.fid = fid
data_object.natoms = len(atoms)

if self.key_mapping is not None:
for _property in self.key_mapping:
# catch for test data not containing labels
if _property in data_object:
new_property = self.key_mapping[_property]
if new_property not in data_object:
data_object[new_property] = data_object[_property]
del data_object[_property]
lbluque marked this conversation as resolved.
Show resolved Hide resolved

# Transform data object
if self.transform is not None:
data_object = self.transform(
data_object, **self.config.get("transform_args", {})
)
data_object = self.transforms(data_object)

if self.config.get("include_relaxed_energy", False):
data_object.y_relaxed = self.get_relaxed_energy(self.ids[idx])
Expand All @@ -137,7 +151,7 @@ def get_atoms_object(self, identifier):
)

@abstractmethod
def load_dataset_get_ids(self, config):
def _load_dataset_get_ids(self, config):
# This function should return a list of ids that can be used to index into the database
raise NotImplementedError(
"Every ASE dataset needs to declare a function to load the dataset and return a list of ids."
Expand All @@ -147,7 +161,7 @@ def close_db(self) -> None:
# This method is sometimes called by a trainer
pass

def guess_target_metadata(self, num_samples: int = 100):
def get_metadata(self, num_samples: int = 100):
metadata = {}

if num_samples < len(self):
Expand All @@ -169,9 +183,6 @@ def guess_target_metadata(self, num_samples: int = 100):

return metadata

def get_metadata(self):
return self.guess_target_metadata()


@registry.register_dataset("ase_read")
class AseReadDataset(AseAtomsDataset):
Expand All @@ -196,7 +207,7 @@ class AseReadDataset(AseAtomsDataset):
default options will work for most users

If you are using this for a training dataset, set
"r_energy":True and/or "r_forces":True as appropriate
"r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate
In that case, energy/forces must be in the files you read (ex. OUTCAR)

ase_read_args (dict): Keyword arguments for ase.io.read()
Expand All @@ -213,14 +224,18 @@ class AseReadDataset(AseAtomsDataset):

transform_args (dict): Additional keyword arguments for the transform callable

key_mapping (dict[str, str]): Dictionary specifying a mapping between the name of a property used
in the model with the corresponding property as it was named in the dataset. Only need to use if
the name is different.

atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms
object. Useful for applying tags, for example.

transform (callable, optional): Additional preprocessing function for the Data object

"""

def load_dataset_get_ids(self, config) -> List[Path]:
def _load_dataset_get_ids(self, config) -> list[Path]:
self.ase_read_args = config.get("ase_read_args", {})

if ":" in self.ase_read_args.get("index", ""):
Expand Down Expand Up @@ -286,7 +301,7 @@ class AseReadMultiStructureDataset(AseAtomsDataset):
default options will work for most users

If you are using this for a training dataset, set
"r_energy":True and/or "r_forces":True as appropriate
"r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate
In that case, energy/forces must be in the files you read (ex. OUTCAR)

ase_read_args (dict): Keyword arguments for ase.io.read()
Expand All @@ -305,13 +320,17 @@ class AseReadMultiStructureDataset(AseAtomsDataset):

transform_args (dict): Additional keyword arguments for the transform callable

key_mapping (dict[str, str]): Dictionary specifying a mapping between the name of a property used
in the model with the corresponding property as it was named in the dataset. Only need to use if
the name is different.

atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms
object. Useful for applying tags, for example.

transform (callable, optional): Additional preprocessing function for the Data object
"""

def load_dataset_get_ids(self, config):
def _load_dataset_get_ids(self, config):
self.ase_read_args = config.get("ase_read_args", {})
if not hasattr(self.ase_read_args, "index"):
self.ase_read_args["index"] = ":"
Expand Down Expand Up @@ -374,32 +393,6 @@ def get_relaxed_energy(self, identifier):
return relaxed_atoms.get_potential_energy(apply_constraint=False)


class dummy_list(list):
    """A lightweight stand-in for ``list(range(max))``.

    Subclasses ``list`` so isinstance checks on dataset id containers pass,
    but stores no elements: indices are synthesized on demand, avoiding
    materializing a potentially huge id list in memory.
    """

    def __init__(self, max) -> None:
        # NOTE: parameter shadows the builtin ``max``; kept as-is for
        # backward compatibility with existing callers.
        self.max = max

    def __len__(self) -> int:
        # Report the virtual length rather than the (empty) backing list.
        return self.max

    def __iter__(self):
        # Fix: the inherited list.__iter__ would yield nothing (no elements
        # are actually stored), contradicting __len__/__getitem__. Iterate
        # over the virtual indices instead.
        return iter(range(self.max))

    def __getitem__(self, idx):
        # Handle slicing
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(self.max))]

        # Cast idx as int since it could be a tensor index
        idx = int(idx)

        # Handle negative indices (referenced from end)
        if idx < 0:
            idx += self.max

        if 0 <= idx < self.max:
            return idx
        # Include a message so failures are debuggable (original raised bare).
        raise IndexError(
            f"index {idx} out of range for dummy_list of length {self.max}"
        )


@registry.register_dataset("ase_db")
class AseDBDataset(AseAtomsDataset):
"""
Expand Down Expand Up @@ -435,7 +428,7 @@ class AseDBDataset(AseAtomsDataset):
default options will work for most users

If you are using this for a training dataset, set
"r_energy":True and/or "r_forces":True as appropriate
"r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate
In that case, energy/forces must be in the database

keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need
Expand All @@ -444,23 +437,34 @@ class AseDBDataset(AseAtomsDataset):

atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable

transform_args (dict): Additional keyword arguments for the transform callable
transforms (dict[str, dict]): Dictionary specifying data transforms as {transform_function: config}
where config is a dictionary specifying arguments to the transform_function

key_mapping (dict[str, str]): Dictionary specifying a mapping between the name of a property used
in the model with the corresponding property as it was named in the dataset. Only need to use if
the name is different.

atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms
object. Useful for applying tags, for example.

transform (callable, optional): Additional preprocessing function for the Data object
transform (callable, optional): deprecated?
lbluque marked this conversation as resolved.
Show resolved Hide resolved
"""

def load_dataset_get_ids(self, config) -> dummy_list:
def _load_dataset_get_ids(self, config: dict) -> list[int]:
if isinstance(config["src"], list):
filepaths = config["src"]
if os.path.isdir(config["src"][0]):
filepaths = reduce(
lambda x, y: x + y,
(glob(f"{path}/*db") for path in config["src"]),
)
else:
filepaths = config["src"]
elif os.path.isfile(config["src"]):
filepaths = [config["src"]]
elif os.path.isdir(config["src"]):
filepaths = glob.glob(f'{config["src"]}/*')
filepaths = glob(f'{config["src"]}/*db')
lbluque marked this conversation as resolved.
Show resolved Hide resolved
else:
filepaths = glob.glob(config["src"])
filepaths = glob(config["src"])

self.dbs = []

Expand Down Expand Up @@ -488,16 +492,24 @@ def load_dataset_get_ids(self, config) -> dummy_list:
if hasattr(db, "ids") and self.select_args == {}:
self.db_ids.append(db.ids)
else:
# this is the slow alternative
self.db_ids.append(
[row.id for row in db.select(**self.select_args)]
)

idlens = [len(ids) for ids in self.db_ids]
self._idlen_cumulative = np.cumsum(idlens).tolist()

return dummy_list(sum(idlens))
return list(range(sum(idlens)))

def get_atoms_object(self, idx: int) -> ase.Atoms:
"""Get atoms object corresponding to datapoint idx. Useful to read other properties not in data object.
Args:
idx (int): index in dataset

def get_atoms_object(self, idx):
Returns:
atoms: ASE atoms corresponding to datapoint idx
"""
# Figure out which db this should be indexed from.
db_idx = bisect.bisect(self._idlen_cumulative, idx)

Expand All @@ -507,20 +519,25 @@ def get_atoms_object(self, idx):
el_idx = idx - self._idlen_cumulative[db_idx - 1]
assert el_idx >= 0

atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx])
atoms_row = self.dbs[db_idx].get(self.db_ids[db_idx][el_idx])
atoms = atoms_row.toatoms()

# put data back into atoms info
if isinstance(atoms_row.data, dict):
atoms.info.update(atoms_row.data)

return atoms

def connect_db(self, address, connect_args={}):
@staticmethod
def connect_db(
address: str | Path, connect_args: Optional[dict] = None
) -> ase.db.core.Database:
if connect_args is None:
connect_args = {}
db_type = connect_args.get("type", "extract_from_name")
if db_type == "lmdb" or (
db_type == "extract_from_name" and address.split(".")[-1] == "lmdb"
if db_type in ("lmdb", "aselmdb") or (
db_type == "extract_from_name"
and str(address).split(".")[-1] in ("lmdb", "aselmdb")
):
return LMDBDatabase(address, readonly=True, **connect_args)
else:
Expand All @@ -531,12 +548,12 @@ def close_db(self) -> None:
if hasattr(db, "close"):
db.close()

def get_metadata(self):
def get_metadata(self, num_samples: int = 100) -> dict:
    """Return target metadata for this dataset.

    Only the first connected database is consulted. If it carries no
    stored metadata, fall back to the parent implementation, which
    guesses metadata from ``num_samples`` random samples.

    Args:
        num_samples: number of samples used to guess metadata when none
            is stored in the first database.

    Returns:
        dict: target metadata; a deep copy is returned so callers cannot
        mutate the metadata cached on the database object.
    """
    logging.warning(
        "You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!"
    )
    if self.dbs[0].metadata == {}:
        return super().get_metadata(num_samples)
    return copy.deepcopy(self.dbs[0].metadata)

Expand Down
Loading