From a27e658ee1e10a21e9cefaae37882720f47e8b54 Mon Sep 17 00:00:00 2001 From: Misko Date: Tue, 20 Aug 2024 22:06:05 +0000 Subject: [PATCH] add test to make sure we are properly checking for max_num_elements --- tests/core/e2e/test_s2ef.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 6b83749c0..f5484928b 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -170,6 +170,26 @@ def test_use_pbc_single(self, configs, tutorial_val_src, torch_deterministic): input_yaml=configs["equiformer_v2"], ) + def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic): + with tempfile.TemporaryDirectory() as tempdirname: + tempdir = Path(tempdirname) + extra_args = {"seed": 0} + with pytest.raises(AssertionError): + _ = _run_main( + rundir=str(tempdir), + update_dict_with={ + "optim": {"max_epochs": 1}, + "model": {"max_num_elements": 2}, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + ), + }, + update_run_args_with=extra_args, + input_yaml=configs["equiformer_v2"], + ) + @pytest.mark.parametrize( ("world_size", "ddp"), [