Skip to content

Commit

Permalink
replace DatasetMetadata with dict
Browse files (browse the repository at this point in the history)
  • Loading branch information
misko committed Dec 19, 2024
1 parent 9071f5c commit 65a3473
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 20 deletions.
29 changes: 10 additions & 19 deletions src/fairchem/core/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@
T_co = TypeVar("T_co", covariant=True)


class DatasetMetadata:
def __init__(self, natoms: ArrayLike | None = None, **kwargs):
self.natoms = natoms
for key, value in kwargs.items():
setattr(self, key, value)


class UnsupportedDatasetError(ValueError):
pass

Expand Down Expand Up @@ -75,14 +68,14 @@ def __len__(self) -> int:
def metadata_hasattr(self, attr) -> bool:
if self._metadata is None:
return False
return hasattr(self._metadata, attr)
return attr in self._metadata

@cached_property
def indices(self):
return np.arange(self.num_samples, dtype=int)

@cached_property
def _metadata(self) -> DatasetMetadata:
def _metadata(self) -> dict[str, ArrayLike]:
# logic to read metadata file here
metadata_npzs = []
if self.config.get("metadata_path", None) is not None:
Expand All @@ -105,25 +98,23 @@ def _metadata(self) -> DatasetMetadata:
)
return None

metadata = DatasetMetadata(
**{
field: np.concatenate([metadata[field] for metadata in metadata_npzs])
for field in metadata_npzs[0]
}
)
metadata = {
field: np.concatenate([metadata[field] for metadata in metadata_npzs])
for field in metadata_npzs[0]
}

assert np.issubdtype(
metadata.natoms.dtype, np.integer
), f"Metadata natoms must be an integer type! not {metadata.natoms.dtype}"
assert metadata.natoms.shape[0] == len(
metadata["natoms"].dtype, np.integer
), f"Metadata natoms must be an integer type! not {metadata['natoms'].dtype}"
assert metadata["natoms"].shape[0] == len(
self
), "Loaded metadata and dataset size mismatch."

return metadata

def get_metadata(self, attr, idx):
if self._metadata is not None:
metadata_attr = getattr(self._metadata, attr)
metadata_attr = self._metadata[attr]
if isinstance(idx, list):
return [metadata_attr[_idx] for _idx in idx]
return metadata_attr[idx]
Expand Down
2 changes: 1 addition & 1 deletion tests/core/datasets/test_create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_create_dataset(key, value, max_atoms, structures, lmdb_database):
structures = [s for s in structures if len(s) <= max_atoms]
assert all(
natoms <= max_atoms
for natoms in dataset.metadata.natoms[range(len(dataset))]
for natoms in dataset.metadata["natoms"][range(len(dataset))]
)
if key == "first_n": # this assumes first_n are not shuffled
assert all(
Expand Down

0 comments on commit 65a3473

Please sign in to comment.