From 65a3473889d7dfa6cc8cdd90392afe38dd283e8c Mon Sep 17 00:00:00 2001 From: Misko Date: Thu, 19 Dec 2024 00:41:36 +0000 Subject: [PATCH] replace DatasetMetadata with dict --- src/fairchem/core/datasets/base_dataset.py | 29 ++++++++-------------- tests/core/datasets/test_create_dataset.py | 2 +- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py index 0a5db9a37..a52abbb3b 100644 --- a/src/fairchem/core/datasets/base_dataset.py +++ b/src/fairchem/core/datasets/base_dataset.py @@ -33,13 +33,6 @@ T_co = TypeVar("T_co", covariant=True) -class DatasetMetadata: - def __init__(self, natoms: ArrayLike | None = None, **kwargs): - self.natoms = natoms - for key, value in kwargs.items(): - setattr(self, key, value) - - class UnsupportedDatasetError(ValueError): pass @@ -75,14 +68,14 @@ def __len__(self) -> int: def metadata_hasattr(self, attr) -> bool: if self._metadata is None: return False - return hasattr(self._metadata, attr) + return attr in self._metadata @cached_property def indices(self): return np.arange(self.num_samples, dtype=int) @cached_property - def _metadata(self) -> DatasetMetadata: + def _metadata(self) -> dict[str, ArrayLike]: # logic to read metadata file here metadata_npzs = [] if self.config.get("metadata_path", None) is not None: @@ -105,17 +98,15 @@ def _metadata(self) -> DatasetMetadata: ) return None - metadata = DatasetMetadata( - **{ - field: np.concatenate([metadata[field] for metadata in metadata_npzs]) - for field in metadata_npzs[0] - } - ) + metadata = { + field: np.concatenate([metadata[field] for metadata in metadata_npzs]) + for field in metadata_npzs[0] + } assert np.issubdtype( - metadata.natoms.dtype, np.integer - ), f"Metadata natoms must be an integer type! not {metadata.natoms.dtype}" - assert metadata.natoms.shape[0] == len( + metadata["natoms"].dtype, np.integer + ), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}" + assert metadata["natoms"].shape[0] == len( self ), "Loaded metadata and dataset size mismatch." @@ -123,7 +114,7 @@ def _metadata(self) -> DatasetMetadata: def get_metadata(self, attr, idx): if self._metadata is not None: - metadata_attr = getattr(self._metadata, attr) + metadata_attr = self._metadata[attr] if isinstance(idx, list): return [metadata_attr[_idx] for _idx in idx] return metadata_attr[idx] diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py index 02f89b593..0bd20bafd 100644 --- a/tests/core/datasets/test_create_dataset.py +++ b/tests/core/datasets/test_create_dataset.py @@ -140,7 +140,7 @@ def test_create_dataset(key, value, max_atoms, structures, lmdb_database): structures = [s for s in structures if len(s) <= max_atoms] assert all( natoms <= max_atoms - for natoms in dataset.metadata.natoms[range(len(dataset))] + for natoms in dataset.metadata["natoms"][range(len(dataset))] ) if key == "first_n": # this assumes first_n are not shuffled assert all(