diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py
index 13e07749d..0b79dfe07 100644
--- a/src/fairchem/core/datasets/base_dataset.py
+++ b/src/fairchem/core/datasets/base_dataset.py
@@ -13,7 +13,6 @@
 from typing import (
     TYPE_CHECKING,
     Any,
-    NamedTuple,
     TypeVar,
 )
 
@@ -34,10 +33,6 @@
 T_co = TypeVar("T_co", covariant=True)
 
 
-class DatasetMetadata(NamedTuple):
-    natoms: ArrayLike | None = None
-
-
 class UnsupportedDatasetError(ValueError):
     pass
 
@@ -71,16 +66,14 @@ def __len__(self) -> int:
         return self.num_samples
 
     def metadata_hasattr(self, attr) -> bool:
-        if self._metadata is None:
-            return False
-        return hasattr(self._metadata, attr)
+        return attr in self._metadata
 
     @cached_property
     def indices(self):
         return np.arange(self.num_samples, dtype=int)
 
     @cached_property
-    def _metadata(self) -> DatasetMetadata:
+    def _metadata(self) -> dict[str, ArrayLike]:
         # logic to read metadata file here
         metadata_npzs = []
         if self.config.get("metadata_path", None) is not None:
@@ -101,26 +94,25 @@ def _metadata(self) -> DatasetMetadata:
             logging.warning(
                 f"Could not find dataset metadata.npz files in '{self.paths}'"
             )
-            return None
+            return {}
+
+        metadata = {
+            field: np.concatenate([metadata[field] for metadata in metadata_npzs])
+            for field in metadata_npzs[0]
+        }
 
-        metadata = DatasetMetadata(
-            **{
-                field: np.concatenate([metadata[field] for metadata in metadata_npzs])
-                for field in DatasetMetadata._fields
-            }
-        )
         assert np.issubdtype(
-            metadata.natoms.dtype, np.integer
-        ), f"Metadata natoms must be an integer type! not {metadata.natoms.dtype}"
-        assert metadata.natoms.shape[0] == len(
+            metadata["natoms"].dtype, np.integer
+        ), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}"
+        assert metadata["natoms"].shape[0] == len(
             self
         ), "Loaded metadata and dataset size mismatch."
 
         return metadata
 
     def get_metadata(self, attr, idx):
-        if self._metadata is not None:
-            metadata_attr = getattr(self._metadata, attr)
+        if attr in self._metadata:
+            metadata_attr = self._metadata[attr]
             if isinstance(idx, list):
                 return [metadata_attr[_idx] for _idx in idx]
             return metadata_attr[idx]
@@ -134,7 +126,7 @@ def __init__(
         self,
         dataset: BaseDataset,
         indices: Sequence[int],
-        metadata: DatasetMetadata | None = None,
+        metadata: dict[str, ArrayLike],
     ) -> None:
         super().__init__(dataset, indices)
         self.metadata = metadata
@@ -143,7 +135,7 @@ def __init__(
         self.config = dataset.config
 
     @cached_property
-    def _metadata(self) -> DatasetMetadata:
+    def _metadata(self) -> dict[str, ArrayLike]:
         return self.dataset._metadata
 
     def get_metadata(self, attr, idx):
@@ -183,6 +175,7 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
     g.manual_seed(seed)
 
     dataset = dataset_cls(current_split_config)
+    # Get indices of the dataset
     indices = dataset.indices
 
     max_atoms = current_split_config.get("max_atoms", None)
@@ -191,6 +184,24 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
             raise ValueError("Cannot use max_atoms without dataset metadata")
         indices = indices[dataset.get_metadata("natoms", indices) <= max_atoms]
 
+    for subset_to in current_split_config.get("subset_to", []):
+        if not dataset.metadata_hasattr(subset_to["metadata_key"]):
+            raise ValueError(
+                f"Cannot use {subset_to} without dataset metadata key {subset_to['metadata_key']}"
+            )
+        if subset_to["op"] == "abs_le":
+            indices = indices[
+                np.abs(dataset.get_metadata(subset_to["metadata_key"], indices))
+                <= subset_to["rhv"]
+            ]
+        elif subset_to["op"] == "in":
+            indices = indices[
+                np.isin(
+                    dataset.get_metadata(subset_to["metadata_key"], indices),
+                    subset_to["rhv"],
+                )
+            ]
+
     # Apply dataset level transforms
     # TODO is no_shuffle mutually exclusive though? or what is the purpose of no_shuffle?
     first_n = current_split_config.get("first_n")
@@ -208,11 +219,17 @@ def create_dataset(config: dict[str, Any], split: str) -> Subset:
         # shuffle all datasets by default to avoid biasing the sampling in concat dataset
         # TODO only shuffle if split is train
         max_index = sample_n
-        indices = indices[randperm(len(indices), generator=g)]
+        indices = (
+            indices
+            if len(indices) == 1
+            else indices[randperm(len(indices), generator=g)]
+        )
     else:
         max_index = len(indices)
         indices = (
-            indices if no_shuffle else indices[randperm(len(indices), generator=g)]
+            indices
+            if (no_shuffle or len(indices) == 1)
+            else indices[randperm(len(indices), generator=g)]
         )
 
     if max_index > len(indices):
diff --git a/tests/core/common/test_data_parallel_batch_sampler.py b/tests/core/common/test_data_parallel_batch_sampler.py
index 6bd8effe2..1bc594c9f 100644
--- a/tests/core/common/test_data_parallel_batch_sampler.py
+++ b/tests/core/common/test_data_parallel_batch_sampler.py
@@ -7,11 +7,15 @@
 
 from __future__ import annotations
 
-from contextlib import contextmanager
-from pathlib import Path
 import functools
 import tempfile
-from typing import TypeVar
+from contextlib import contextmanager
+from pathlib import Path
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    TypeVar,
+)
 
 import numpy as np
 import pytest
@@ -23,7 +27,10 @@
     UnsupportedDatasetError,
     _balanced_partition,
 )
-from fairchem.core.datasets.base_dataset import BaseDataset, DatasetMetadata
+
+if TYPE_CHECKING:
+    from numpy.typing import ArrayLike
+from fairchem.core.datasets.base_dataset import BaseDataset
 
 DATA = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
 SIZE_ATOMS = [2, 20, 3, 51, 10, 11, 41, 31, 13, 14]
@@ -41,8 +48,8 @@ def _temp_file(name: str):
 def valid_dataset():
     class _Dataset(BaseDataset):
         @functools.cached_property
-        def _metadata(self) -> DatasetMetadata:
-            return DatasetMetadata(natoms=np.array(SIZE_ATOMS))
+        def _metadata(self) -> dict[str, ArrayLike]:
+            return {"natoms": np.array(SIZE_ATOMS)}
 
         def __init__(self, data) -> None:
             super().__init__(config={})
@@ -56,7 +63,7 @@ def __getitem__(self, idx):
 
         def get_metadata(self, attr, idx):
             assert attr == "natoms"
-            metadata_attr = getattr(self._metadata, attr)
+            metadata_attr = self._metadata[attr]
             if isinstance(idx, list):
                 return [metadata_attr[_idx] for _idx in idx]
             return metadata_attr[idx]
@@ -68,19 +75,19 @@ def get_metadata(self, attr, idx):
 def valid_path_dataset():
     class _Dataset(BaseDataset):
         @functools.cached_property
-        def _metadata(self) -> DatasetMetadata:
+        def _metadata(self) -> dict[str, ArrayLike]:
             return self.metadata
 
         def __init__(self, data, fpath: Path) -> None:
             super().__init__(config={})
             self.data = data
-            self.metadata = DatasetMetadata(natoms=np.load(fpath)["natoms"])
+            self.metadata = {"natoms": np.load(fpath)["natoms"]}
 
         def __len__(self):
             return len(self.data)
 
         def __getitem__(self, idx):
-            metadata_attr = getattr(self._metadata, "natoms")
+            metadata_attr = self._metadata["natoms"]
             if isinstance(idx, list):
                 return [metadata_attr[_idx] for _idx in idx]
             return metadata_attr[idx]
@@ -96,7 +103,6 @@ def __getitem__(self, idx):
 @pytest.fixture()
 def invalid_path_dataset():
     class _Dataset(BaseDataset):
-
         def __init__(self, data) -> None:
             super().__init__(config={})
             self.data = data
@@ -114,7 +120,6 @@ def __getitem__(self, idx):
 @pytest.fixture()
 def invalid_dataset():
     class _Dataset(BaseDataset):
-
         def __init__(self, data) -> None:
             super().__init__(config={})
             self.data = data
diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py
index 1dc17bcb3..0bd20bafd 100644
--- a/tests/core/datasets/test_create_dataset.py
+++ b/tests/core/datasets/test_create_dataset.py
@@ -1,10 +1,13 @@
+from __future__ import annotations
+
 import os
+import tempfile
+
 import numpy as np
 import pytest
 
 from fairchem.core.datasets import LMDBDatabase, create_dataset
 from fairchem.core.datasets.base_dataset import BaseDataset
-import tempfile
 from fairchem.core.trainers.base_trainer import BaseTrainer
 
 
@@ -12,12 +15,21 @@
 def lmdb_database(structures):
     with tempfile.TemporaryDirectory() as tmpdirname:
         num_atoms = []
+        mod2 = []
+        mod3 = []
         asedb_fn = f"{tmpdirname}/asedb.lmdb"
         with LMDBDatabase(asedb_fn) as database:
             for i, atoms in enumerate(structures):
                 database.write(atoms, data=atoms.info)
                 num_atoms.append(len(atoms))
-        np.savez(f"{tmpdirname}/metadata.npz", natoms=num_atoms)
+                mod2.append(len(atoms) % 2)
+                mod3.append(len(atoms) % 3)
+        np.savez(
+            f"{tmpdirname}/metadata.npz",
+            natoms=num_atoms,
+            mod2=mod2,
+            mod3=mod3,
+        )
 
         yield asedb_fn
 
@@ -76,6 +88,40 @@ def get_dataloader(self, *args, **kwargs):
     assert len(t.val_dataset) == 3
 
 
+def test_subset_to(structures, lmdb_database):
+    config = {
+        "format": "ase_db",
+        "src": str(lmdb_database),
+        "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 10}],
+    }
+
+    assert len(create_dataset(config, split="train")) == len(structures)
+
+    # only select those that have mod2==0
+    config = {
+        "format": "ase_db",
+        "src": str(lmdb_database),
+        "subset_to": [{"op": "abs_le", "metadata_key": "mod2", "rhv": 0}],
+    }
+    assert len(create_dataset(config, split="train")) == len(
+        [s for s in structures if len(s) % 2 == 0]
+    )
+
+    # stacking the same mod2==0 filter twice should select the same subset as applying it once
+    config = {
+        "format": "ase_db",
+        "src": str(lmdb_database),
+        "subset_to": [
+            {"op": "abs_le", "metadata_key": "mod2", "rhv": 0},
+            {"op": "abs_le", "metadata_key": "mod2", "rhv": 0},
+        ],
+    }
+    assert len(create_dataset(config, split="train")) == len(
+        [s for s in structures if len(s) % 2 == 0]
+    )
+    assert len([s for s in structures if len(s) % 2 == 0]) > 0
+
+
 @pytest.mark.parametrize("max_atoms", [3, None])
 @pytest.mark.parametrize(
     "key, value", [("first_n", 2), ("sample_n", 2), ("no_shuffle", True)]
 )
@@ -94,7 +140,7 @@ def test_create_dataset(key, value, max_atoms, structures, lmdb_database):
         structures = [s for s in structures if len(s) <= max_atoms]
         assert all(
             natoms <= max_atoms
-            for natoms in dataset.metadata.natoms[range(len(dataset))]
+            for natoms in dataset.metadata["natoms"][range(len(dataset))]
         )
     if key == "first_n":  # this assumes first_n are not shuffled
         assert all(
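For reference, the subset_to option introduced in create_dataset above filters a split's indices against arrays stored in the dataset's metadata (e.g. a metadata.npz next to the source), with "abs_le" keeping entries whose absolute metadata value is <= "rhv" and "in" keeping entries whose value is contained in "rhv", applied in the order given. A minimal usage sketch mirroring test_subset_to; the database path is hypothetical and the metadata file is assumed to contain a natoms field:

    from fairchem.core.datasets import create_dataset

    # Hypothetical config: keep structures with natoms <= 20 whose natoms value
    # is also one of the listed sizes; the filters are applied sequentially.
    config = {
        "format": "ase_db",
        "src": "/path/to/asedb.lmdb",  # hypothetical path, metadata.npz assumed alongside
        "subset_to": [
            {"op": "abs_le", "metadata_key": "natoms", "rhv": 20},
            {"op": "in", "metadata_key": "natoms", "rhv": [8, 12, 16, 20]},
        ],
    }

    train_dataset = create_dataset(config, split="train")
    print(len(train_dataset))

If a requested metadata_key is not present in the loaded metadata, create_dataset raises a ValueError, matching the metadata_hasattr check in the hunk above.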