From 1bee0d71cc96e15293c49382c6ad80840ca5dc57 Mon Sep 17 00:00:00 2001
From: Misko
Date: Wed, 21 Aug 2024 17:07:47 -0700
Subject: [PATCH] Add check to max num atoms (#817)

* add assert for max_num_atoms

* add test to make sure we are properly checking for max_num_elements

* fix post merge
---
 .../models/equiformer_v2/equiformer_v2.py     |  3 +++
 src/fairchem/core/models/escn/escn.py         |  3 +++
 tests/core/e2e/test_s2ef.py                   | 20 +++++++++++++++++++
 3 files changed, 26 insertions(+)

diff --git a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
index 978d4c226..61b62be16 100644
--- a/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
+++ b/src/fairchem/core/models/equiformer_v2/equiformer_v2.py
@@ -397,6 +397,9 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
         self.dtype = data.pos.dtype
         self.device = data.pos.device
         atomic_numbers = data.atomic_numbers.long()
+        assert (
+            atomic_numbers.max().item() < self.max_num_elements
+        ), "Atomic number exceeds that given in model config"
         graph = self.generate_graph(
             data,
             enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly,
diff --git a/src/fairchem/core/models/escn/escn.py b/src/fairchem/core/models/escn/escn.py
index 54b1992f4..6eb95947a 100644
--- a/src/fairchem/core/models/escn/escn.py
+++ b/src/fairchem/core/models/escn/escn.py
@@ -235,6 +235,9 @@ def forward(self, data):
         start_time = time.time()

         atomic_numbers = data.atomic_numbers.long()
+        assert (
+            atomic_numbers.max().item() < self.max_num_elements
+        ), "Atomic number exceeds that given in model config"
         num_atoms = len(atomic_numbers)

         graph = self.generate_graph(data)
diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py
index 6b83749c0..2f7dfa373 100644
--- a/tests/core/e2e/test_s2ef.py
+++ b/tests/core/e2e/test_s2ef.py
@@ -170,6 +170,26 @@ def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic):
             input_yaml=configs["equiformer_v2"],
         )

+    def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic):
+        with tempfile.TemporaryDirectory() as tempdirname:
+            tempdir = Path(tempdirname)
+            extra_args = {"seed": 0}
+            with pytest.raises(AssertionError):
+                _ = _run_main(
+                    rundir=str(tempdir),
+                    update_dict_with={
+                        "optim": {"max_epochs": 1},
+                        "model": {"backbone": {"max_num_elements": 2}},
+                        "dataset": oc20_lmdb_train_and_val_from_paths(
+                            train_src=str(tutorial_val_src),
+                            val_src=str(tutorial_val_src),
+                            test_src=str(tutorial_val_src),
+                        ),
+                    },
+                    update_run_args_with=extra_args,
+                    input_yaml=configs["equiformer_v2_hydra"],
+                )
+
     @pytest.mark.parametrize(
         ("world_size", "ddp"),
         [