Filter based on additional metadata fields (#948)
* add in info fields

add in info fields and limits

* fix up datasetmetadata

* lint

* add some tests

* replace DatasetMetadata with dict

* remove datasetmetadata

* remove datasetmetadata

* remove info_fields, there is already a way to do this using r_data_keys
misko authored Dec 19, 2024
1 parent 83e1a53 commit 3a9732a
Showing 3 changed files with 108 additions and 40 deletions.
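In practice, the new `subset_to` option is a list of filter specs applied to the dataset's index array; each spec names an `op` (`abs_le` or `in`), a `metadata_key` stored in the dataset's metadata.npz, and a right-hand value `rhv` to compare against. A minimal sketch of a config using it — the `src` path and the `mod2`/`mod3` metadata keys are illustrative, not part of this commit:

from fairchem.core.datasets import create_dataset

# Hypothetical config: keep only entries whose |mod2| <= 0 and whose mod3
# value is in {0, 1}. These keys must exist as fields in metadata.npz.
config = {
    "format": "ase_db",
    "src": "/path/to/asedb.lmdb",  # illustrative path
    "subset_to": [
        {"op": "abs_le", "metadata_key": "mod2", "rhv": 0},
        {"op": "in", "metadata_key": "mod3", "rhv": [0, 1]},
    ],
}
train_dataset = create_dataset(config, split="train")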
67 changes: 42 additions & 25 deletions src/fairchem/core/datasets/base_dataset.py
@@ -13,7 +13,6 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    NamedTuple,
     TypeVar,
 )
@@ -34,10 +33,6 @@
 T_co = TypeVar("T_co", covariant=True)
 
 
-class DatasetMetadata(NamedTuple):
-    natoms: ArrayLike | None = None
-
-
 class UnsupportedDatasetError(ValueError):
     pass
@@ -71,16 +66,14 @@ def __len__(self) -> int:
         return self.num_samples
 
     def metadata_hasattr(self, attr) -> bool:
-        if self._metadata is None:
-            return False
-        return hasattr(self._metadata, attr)
+        return attr in self._metadata
 
     @cached_property
     def indices(self):
         return np.arange(self.num_samples, dtype=int)
 
     @cached_property
-    def _metadata(self) -> DatasetMetadata:
+    def _metadata(self) -> dict[str, ArrayLike]:
         # logic to read metadata file here
         metadata_npzs = []
         if self.config.get("metadata_path", None) is not None:
@@ -101,26 +94,25 @@ def _metadata(self) -> DatasetMetadata:
             logging.warning(
                 f"Could not find dataset metadata.npz files in '{self.paths}'"
             )
-            return None
+            return {}
 
-        metadata = DatasetMetadata(
-            **{
-                field: np.concatenate([metadata[field] for metadata in metadata_npzs])
-                for field in DatasetMetadata._fields
-            }
-        )
+        metadata = {
+            field: np.concatenate([metadata[field] for metadata in metadata_npzs])
+            for field in metadata_npzs[0]
+        }
+
         assert np.issubdtype(
-            metadata.natoms.dtype, np.integer
-        ), f"Metadata natoms must be an integer type! not {metadata.natoms.dtype}"
-        assert metadata.natoms.shape[0] == len(
+            metadata["natoms"].dtype, np.integer
+        ), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}"
+        assert metadata["natoms"].shape[0] == len(
             self
         ), "Loaded metadata and dataset size mismatch."
 
         return metadata
 
     def get_metadata(self, attr, idx):
-        if self._metadata is not None:
-            metadata_attr = getattr(self._metadata, attr)
+        if attr in self._metadata:
+            metadata_attr = self._metadata[attr]
             if isinstance(idx, list):
                 return [metadata_attr[_idx] for _idx in idx]
             return metadata_attr[idx]
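Since `_metadata` now loads every array found in the metadata.npz files (rather than only the fields of the removed `DatasetMetadata` NamedTuple), any extra field saved alongside the required integer `natoms` array becomes available to `get_metadata` and to the `subset_to` filters. A small sketch of producing such a file — field names other than `natoms` are arbitrary:

import numpy as np

# natoms is required: an integer array with one entry per dataset sample.
natoms = np.array([2, 20, 3, 51, 10])
np.savez(
    "metadata.npz",
    natoms=natoms,
    mod2=natoms % 2,  # illustrative extra field, filterable via subset_to
)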
@@ -134,7 +126,7 @@ def __init__(
         self,
         dataset: BaseDataset,
         indices: Sequence[int],
-        metadata: DatasetMetadata | None = None,
+        metadata: dict[str, ArrayLike],
     ) -> None:
         super().__init__(dataset, indices)
         self.metadata = metadata
@@ -143,7 +135,7 @@ def __init__(
         self.config = dataset.config
 
     @cached_property
-    def _metadata(self) -> DatasetMetadata:
+    def _metadata(self) -> dict[str, ArrayLike]:
         return self.dataset._metadata
 
     def get_metadata(self, attr, idx):
@@ -183,6 +175,7 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
     g.manual_seed(seed)
 
     dataset = dataset_cls(current_split_config)
+
     # Get indices of the dataset
     indices = dataset.indices
     max_atoms = current_split_config.get("max_atoms", None)
@@ -191,6 +184,24 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
             raise ValueError("Cannot use max_atoms without dataset metadata")
         indices = indices[dataset.get_metadata("natoms", indices) <= max_atoms]
 
+    for subset_to in current_split_config.get("subset_to", []):
+        if not dataset.metadata_hasattr(subset_to["metadata_key"]):
+            raise ValueError(
+                f"Cannot use {subset_to} without dataset metadata key {subset_to['metadata_key']}"
+            )
+        if subset_to["op"] == "abs_le":
+            indices = indices[
+                np.abs(dataset.get_metadata(subset_to["metadata_key"], indices))
+                <= subset_to["rhv"]
+            ]
+        elif subset_to["op"] == "in":
+            indices = indices[
+                np.isin(
+                    dataset.get_metadata(subset_to["metadata_key"], indices),
+                    subset_to["rhv"],
+                )
+            ]
+
     # Apply dataset level transforms
     # TODO is no_shuffle mutually exclusive though? or what is the purpose of no_shuffle?
     first_n = current_split_config.get("first_n")
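Stripped of the dataset machinery, the two ops are plain numpy masks over the metadata values gathered at the current indices — roughly (standalone sketch, not fairchem code):

import numpy as np

indices = np.arange(10)
values = indices * 3 - 5  # stand-in for dataset.get_metadata(key, indices)

# "abs_le": keep indices whose |value| <= rhv
kept_abs_le = indices[np.abs(values) <= 4]     # -> [1 2 3]

# "in": keep indices whose value is in the collection rhv
kept_in = indices[np.isin(values, [1, 4, 7])]  # -> [2 3 4]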
@@ -208,11 +219,17 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
         # shuffle all datasets by default to avoid biasing the sampling in concat dataset
         # TODO only shuffle if split is train
         max_index = sample_n
-        indices = indices[randperm(len(indices), generator=g)]
+        indices = (
+            indices
+            if len(indices) == 1
+            else indices[randperm(len(indices), generator=g)]
+        )
     else:
         max_index = len(indices)
         indices = (
-            indices if no_shuffle else indices[randperm(len(indices), generator=g)]
+            indices
+            if (no_shuffle or len(indices) == 1)
+            else indices[randperm(len(indices), generator=g)]
         )
 
     if max_index > len(indices):
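Aside: the guard added above skips the permutation for single-element index arrays, where shuffling is a no-op. Outside the dataset plumbing the pattern is simply (standalone sketch, assuming numpy and torch):

import numpy as np
import torch

g = torch.Generator()
g.manual_seed(0)

indices = np.arange(5)
if len(indices) > 1:  # a length-1 array needs no shuffle
    perm = torch.randperm(len(indices), generator=g).numpy()
    indices = indices[perm]
print(indices)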
29 changes: 17 additions & 12 deletions tests/core/common/test_data_parallel_batch_sampler.py
@@ -7,11 +7,15 @@
 
 from __future__ import annotations
 
-from contextlib import contextmanager
-from pathlib import Path
 import functools
 import tempfile
-from typing import TypeVar
+from contextlib import contextmanager
+from pathlib import Path
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    TypeVar,
+)
 
 import numpy as np
 import pytest
@@ -23,7 +27,10 @@
     UnsupportedDatasetError,
     _balanced_partition,
 )
-from fairchem.core.datasets.base_dataset import BaseDataset, DatasetMetadata
+
+if TYPE_CHECKING:
+    from numpy.typing import ArrayLike
+from fairchem.core.datasets.base_dataset import BaseDataset
 
 DATA = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
 SIZE_ATOMS = [2, 20, 3, 51, 10, 11, 41, 31, 13, 14]
@@ -41,8 +48,8 @@ def _temp_file(name: str):
 def valid_dataset():
     class _Dataset(BaseDataset):
         @functools.cached_property
-        def _metadata(self) -> DatasetMetadata:
-            return DatasetMetadata(natoms=np.array(SIZE_ATOMS))
+        def _metadata(self) -> dict[str, ArrayLike]:
+            return {"natoms": np.array(SIZE_ATOMS)}
 
         def __init__(self, data) -> None:
             super().__init__(config={})
@@ -56,7 +63,7 @@ def __getitem__(self, idx):
 
         def get_metadata(self, attr, idx):
             assert attr == "natoms"
-            metadata_attr = getattr(self._metadata, attr)
+            metadata_attr = self._metadata[attr]
             if isinstance(idx, list):
                 return [metadata_attr[_idx] for _idx in idx]
             return metadata_attr[idx]
@@ -68,19 +75,19 @@ def get_metadata(self, attr, idx):
 def valid_path_dataset():
     class _Dataset(BaseDataset):
         @functools.cached_property
-        def _metadata(self) -> DatasetMetadata:
+        def _metadata(self) -> dict[str, ArrayLike]:
             return self.metadata
 
         def __init__(self, data, fpath: Path) -> None:
             super().__init__(config={})
             self.data = data
-            self.metadata = DatasetMetadata(natoms=np.load(fpath)["natoms"])
+            self.metadata = {"natoms": np.load(fpath)["natoms"]}
 
         def __len__(self):
             return len(self.data)
 
         def __getitem__(self, idx):
-            metadata_attr = getattr(self._metadata, "natoms")
+            metadata_attr = self._metadata["natoms"]
             if isinstance(idx, list):
                 return [metadata_attr[_idx] for _idx in idx]
             return metadata_attr[idx]
@@ -96,7 +103,6 @@ def __getitem__(self, idx):
 @pytest.fixture()
 def invalid_path_dataset():
     class _Dataset(BaseDataset):
-
         def __init__(self, data) -> None:
             super().__init__(config={})
             self.data = data
@@ -114,7 +120,6 @@ def __getitem__(self, idx):
 @pytest.fixture()
 def invalid_dataset():
     class _Dataset(BaseDataset):
-
         def __init__(self, data) -> None:
             super().__init__(config={})
             self.data = data
52 changes: 49 additions & 3 deletions tests/core/datasets/test_create_dataset.py
@@ -1,23 +1,35 @@
 from __future__ import annotations
 
+import os
+import tempfile
+
 import numpy as np
 import pytest
 
 from fairchem.core.datasets import LMDBDatabase, create_dataset
 from fairchem.core.datasets.base_dataset import BaseDataset
-import tempfile
 from fairchem.core.trainers.base_trainer import BaseTrainer
 
 
 @pytest.fixture()
 def lmdb_database(structures):
     with tempfile.TemporaryDirectory() as tmpdirname:
         num_atoms = []
+        mod2 = []
+        mod3 = []
         asedb_fn = f"{tmpdirname}/asedb.lmdb"
         with LMDBDatabase(asedb_fn) as database:
             for i, atoms in enumerate(structures):
                 database.write(atoms, data=atoms.info)
                 num_atoms.append(len(atoms))
-        np.savez(f"{tmpdirname}/metadata.npz", natoms=num_atoms)
+                mod2.append(len(atoms) % 2)
+                mod3.append(len(atoms) % 3)
+        np.savez(
+            f"{tmpdirname}/metadata.npz",
+            natoms=num_atoms,
+            mod2=mod2,
+            mod3=mod3,
+        )
         yield asedb_fn
@@ -76,6 +88,40 @@ def get_dataloader(self, *args, **kwargs):
     assert len(t.val_dataset) == 3
 
 
+def test_subset_to(structures, lmdb_database):
+    config = {
+        "format": "ase_db",
+        "src": str(lmdb_database),
+        "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 10}],
+    }
+
+    assert len(create_dataset(config, split="train")) == len(structures)
+
+    # only select those that have mod2==0
+    config = {
+        "format": "ase_db",
+        "src": str(lmdb_database),
+        "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 0}],
+    }
+    assert len(create_dataset(config, split="train")) == len(
+        [s for s in structures if len(s) % 2 == 0]
+    )
+
+    # applying the same mod2==0 filter twice is idempotent
+    config = {
+        "format": "ase_db",
+        "src": str(lmdb_database),
+        "subset_to": [
+            {"op": "abs_le", "metadata_key": "mod2", "rhv": 0},
+            {"op": "abs_le", "metadata_key": "mod2", "rhv": 0},
+        ],
+    }
+    assert len(create_dataset(config, split="train")) == len(
+        [s for s in structures if len(s) % 2 == 0]
+    )
+    assert len([s for s in structures if len(s) % 2 == 0]) > 0
+
+
 @pytest.mark.parametrize("max_atoms", [3, None])
 @pytest.mark.parametrize(
     "key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)]
@@ -94,7 +140,7 @@ def test_create_dataset(key, value, max_atoms, structures, lmdb_database):
         structures = [s for s in structures if len(s) <= max_atoms]
         assert all(
             natoms <= max_atoms
-            for natoms in dataset.metadata.natoms[range(len(dataset))]
+            for natoms in dataset.metadata["natoms"][range(len(dataset))]
         )
     if key == "first_n":  # this assumes first_n are not shuffled
         assert all(
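Taken together with the existing `max_atoms` filter, a config can now chain several metadata-driven subset operations. A hypothetical end-to-end sketch — the path and the `mod3` key are illustrative, and it assumes a metadata.npz is discoverable alongside `src` (or via `metadata_path`):

from fairchem.core.datasets import create_dataset

config = {
    "format": "ase_db",
    "src": "/path/to/asedb.lmdb",  # illustrative path
    "max_atoms": 50,               # uses the required "natoms" metadata field
    "subset_to": [{"op": "in", "metadata_key": "mod3", "rhv": [0]}],
}
train_dataset = create_dataset(config, split="train")
print(len(train_dataset))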
