From 9abc429cc7badaafc59c54e5830c0ce90dd6d223 Mon Sep 17 00:00:00 2001 From: Luis Barroso-Luque Date: Fri, 19 Jul 2024 10:45:33 -0700 Subject: [PATCH] Fix dataset logic (#771) * adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764) * adding new notebook for using fairchem models with NEBs * adding md tutorials * blocking code cells that arent needed or take too long * fix dataset config logic * add empty val/test if not defined * add empty dicts for all missing datasets --------- Co-authored-by: Brook Wander <73855115+brookwander@users.noreply.github.com> Co-authored-by: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> Co-authored-by: zulissimeta <122578103+zulissimeta@users.noreply.github.com> --- src/fairchem/core/trainers/base_trainer.py | 56 +++++++++------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 12e2e61e7e..dce5099452 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -159,12 +159,18 @@ def __init__( if len(dataset) > 2: self.config["test_dataset"] = dataset[2] elif isinstance(dataset, dict): - self.config["dataset"] = dataset.get("train", {}) - self.config["val_dataset"] = dataset.get("val", {}) - self.config["test_dataset"] = dataset.get("test", {}) - self.config["relax_dataset"] = dataset.get("relax", {}) + # or {} in cases where "dataset": None is explicitly defined + self.config["dataset"] = dataset.get("train", {}) or {} + self.config["val_dataset"] = dataset.get("val", {}) or {} + self.config["test_dataset"] = dataset.get("test", {}) or {} + self.config["relax_dataset"] = dataset.get("relax", {}) or {} else: - self.config["dataset"] = dataset + self.config["dataset"] = dataset or {} + + # add empty dicts for missing datasets + for dataset_name in ("val_dataset", "test_dataset", "relax_dataset"): + if dataset_name not in self.config: + self.config[dataset_name] = {} if not is_debug and distutils.is_master(): os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True) @@ -277,19 +283,8 @@ def load_datasets(self) -> None: self.val_loader = None self.test_loader = None - # Default all of the dataset portions to {} if - # they don't exist, or are null - if not self.config.get("dataset", None): - self.config["dataset"] = {} - if not self.config.get("val_dataset", None): - self.config["val_dataset"] = {} - if not self.config.get("test_dataset", None): - self.config["test_dataset"] = {} - if not self.config.get("relax_dataset", None): - self.config["relax_dataset"] = {} - # load train, val, test datasets - if self.config["dataset"] and self.config["dataset"].get("src", None): + if "src" in self.config["dataset"]: logging.info( f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}" ) @@ -307,7 +302,7 @@ def load_datasets(self) -> None: self.train_sampler, ) - if self.config["val_dataset"]: + if "src" in self.config["val_dataset"]: if self.config["val_dataset"].get("use_train_settings", True): val_config = self.config["dataset"].copy() val_config.update(self.config["val_dataset"]) @@ -329,13 +324,8 @@ def load_datasets(self) -> None: self.val_sampler, ) - if self.config["test_dataset"]: - if ( - self.config["test_dataset"].get("use_train_settings", True) - and self.config[ - "dataset" - ] # if there's no training dataset, we have nothing to copy - ): + if "src" in self.config["test_dataset"]: + if self.config["test_dataset"].get("use_train_settings", True): test_config = self.config["dataset"].copy() test_config.update(self.config["test_dataset"]) else: @@ -407,16 +397,16 @@ def load_task(self): "outputs" ][target_name].get("level", "system") if "train_on_free_atoms" not in self.output_targets[subtarget]: - self.output_targets[subtarget][ - "train_on_free_atoms" - ] = self.config["outputs"][target_name].get( - "train_on_free_atoms", True + self.output_targets[subtarget]["train_on_free_atoms"] = ( + self.config[ + "outputs" + ][target_name].get("train_on_free_atoms", True) ) if "eval_on_free_atoms" not in self.output_targets[subtarget]: - self.output_targets[subtarget][ - "eval_on_free_atoms" - ] = self.config["outputs"][target_name].get( - "eval_on_free_atoms", True + self.output_targets[subtarget]["eval_on_free_atoms"] = ( + self.config[ + "outputs" + ][target_name].get("eval_on_free_atoms", True) ) # TODO: Assert that all targets, loss fn, metrics defined are consistent