diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py
index 4dc04915b..669449f0b 100644
--- a/src/fairchem/core/common/utils.py
+++ b/src/fairchem/core/common/utils.py
@@ -1102,6 +1102,35 @@ def _report_incompat_keys(
     return missing_keys, unexpected_keys
 
 
+def match_state_dict(
+    model_state_dict: Mapping[str, torch.Tensor],
+    checkpoint_state_dict: Mapping[str, torch.Tensor],
+) -> dict:
+    # match the model's state dict with the checkpoint state dict and return a new dict
+    # that's compatible with the model
+
+    # Match the "module." count in the keys of model and checkpoint state_dict
+    # DataParallel model has 1 "module.", DistributedDataParallel has 2 "module."
+    # Not using either of the above two would have no "module."
+
+    ckpt_key_count = next(iter(checkpoint_state_dict)).count("module")
+    mod_key_count = next(iter(model_state_dict)).count("module")
+    key_count_diff = mod_key_count - ckpt_key_count
+
+    if key_count_diff > 0:
+        new_dict = {
+            key_count_diff * "module." + k: v for k, v in checkpoint_state_dict.items()
+        }
+    elif key_count_diff < 0:
+        new_dict = {
+            k[len("module.") * abs(key_count_diff) :]: v
+            for k, v in checkpoint_state_dict.items()
+        }
+    else:
+        new_dict = checkpoint_state_dict
+    return new_dict
+
+
 def load_state_dict(
     module: nn.Module,
     state_dict: Mapping[str, torch.Tensor],
diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py
index 8ce8f3fcb..7380d036c 100644
--- a/src/fairchem/core/models/base.py
+++ b/src/fairchem/core/models/base.py
@@ -7,8 +7,9 @@
 
 from __future__ import annotations
 
+import copy
 import logging
-from abc import ABCMeta, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
@@ -227,8 +228,19 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
         return
 
 
+class HydraInterface(ABC):
+    # a hydra has a backbone and heads
+    @abstractmethod
+    def get_backbone(self) -> BackboneInterface:
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_heads(self) -> dict[str, HeadInterface]:
+        raise NotImplementedError
+
+
 @registry.register_model("hydra")
-class HydraModel(nn.Module, GraphModelMixin):
+class HydraModel(nn.Module, GraphModelMixin, HydraInterface):
     def __init__(
         self,
         backbone: dict,
@@ -237,6 +249,9 @@ def __init__(
     ):
         super().__init__()
         self.otf_graph = otf_graph
+        # make a copy so we don't modify the original config
+        backbone = copy.deepcopy(backbone)
+        heads = copy.deepcopy(heads)
 
         backbone_model_name = backbone.pop("model")
         self.backbone: BackboneInterface = registry.get_model_class(
@@ -272,3 +287,9 @@ def forward(self, data: Batch):
             out.update(self.output_heads[k](data, emb))
 
         return out
+
+    def get_backbone(self) -> BackboneInterface:
+        return self.backbone
+
+    def get_heads(self) -> dict[str, HeadInterface]:
+        return self.output_heads
diff --git a/src/fairchem/core/models/finetune_hydra.py b/src/fairchem/core/models/finetune_hydra.py
new file mode 100644
index 000000000..6c271e24e
--- /dev/null
+++ b/src/fairchem/core/models/finetune_hydra.py
@@ -0,0 +1,177 @@
+from __future__ import annotations
+
+import copy
+import errno
+import logging
+import os
+from enum import Enum
+from typing import TYPE_CHECKING
+
+import torch
+from torch import nn
+
+from fairchem.core.common.registry import registry
+from fairchem.core.common.utils import load_state_dict, match_state_dict
+from fairchem.core.models.base import BackboneInterface, HeadInterface, HydraInterface
+
+if TYPE_CHECKING:
+    from
torch_geometric.data import Batch + +FTHYDRA_NAME = "finetune_hydra" + +class FineTuneMode(Enum): + # in DATA_ONLY, we load the entire model and only finetune on new data + DATA_ONLY = 1 + # in this mode, we only load the Backbone and feed the output of the backbone + # to new heads that are specified + RETAIN_BACKBONE_ONLY = 2 + + +def get_model_config_from_checkpoint(checkpoint_path: str) -> dict: + if not os.path.isfile(checkpoint_path): + raise FileNotFoundError( + errno.ENOENT, "Checkpoint file not found", checkpoint_path + ) + checkpoint = torch.load(checkpoint_path) + return checkpoint["config"]["model"] + + +def load_hydra_model(checkpoint_path: str) -> HydraInterface: + if not os.path.isfile(checkpoint_path): + raise FileNotFoundError( + errno.ENOENT, "Checkpoint file not found", checkpoint_path + ) + logging.info(f"Loading checkpoint from: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path) + config = checkpoint["config"]["model"] + name = config.pop("name") + hydra_model = registry.get_model_class(name)(**config) + assert isinstance( + hydra_model, HydraInterface + ), "Can only load models with the HydraInterface" + matched_dict = match_state_dict(hydra_model.state_dict(), checkpoint["state_dict"]) + load_state_dict(hydra_model, matched_dict, strict=True) + return hydra_model + + +class FTConfig: + FT_CONFIG_NAME = "finetune_config" + STARTING_CHECKPOINT = "starting_checkpoint" + STARTING_MODEL = "starting_model" + MODE = "mode" + HEADS = "heads" + + def __init__(self, config: dict): + self.config = config + self._mode = FineTuneMode[self.config[FTConfig.MODE]] + assert ( + (FTConfig.STARTING_CHECKPOINT in self.config) + or (FTConfig.STARTING_MODEL in self.config) + ), "Either a starting checkpoint or a starting model must be provided!" + assert FTConfig.MODE in self.config + if self._mode == FineTuneMode.RETAIN_BACKBONE_ONLY: + # in this mode, we keep the backbone but attach new output heads specified in head config + assert ( + FTConfig.HEADS in self.config + ), "heads cannot be empty when using RETAIN_BACKBONE_ONLY mode!" 
+
+    def load_model(self) -> nn.Module:
+        # if provided a hydra config to start, build from the starting hydra model;
+        # the weights are assumed to be loaded later from the state_dict in the checkpoint.pt file,
+        # so no actual weights are loaded here
+        if FTConfig.STARTING_MODEL in self.config:
+            # register model from hydra_config
+            config_copy = copy.deepcopy(self.config[FTConfig.STARTING_MODEL])
+            name = config_copy.pop("name")
+            hydra_model = registry.get_model_class(name)(**config_copy)
+        # if provided a checkpoint to start, then load the model and weights from the given checkpoint
+        # this is used at the beginning of a finetuning run
+        elif FTConfig.STARTING_CHECKPOINT in self.config:
+            hydra_model: HydraInterface = load_hydra_model(
+                self.config[FTConfig.STARTING_CHECKPOINT]
+            )
+        assert isinstance(hydra_model, HydraInterface)
+
+        num_params = sum(p.numel() for p in hydra_model.parameters())
+        logging.info(f"Loaded original hydra model with {num_params} params")
+        return hydra_model
+
+    def get_standalone_config(self) -> dict:
+        # replace a config that points at a checkpoint with one that carries the model config itself
+        # this is required for standalone prediction (so we don't need to ship the original checkpoint),
+        # multi-round finetuning, and better robustness
+        standalone_config = {
+            "name": FTHYDRA_NAME,
+            FTConfig.FT_CONFIG_NAME: self.config,
+        }
+        if FTConfig.STARTING_CHECKPOINT in self.config:
+            # modify the config to store the original model config inside model attrs
+            # so we don't need the checkpoint again when loading from this checkpoint
+            new_config = copy.deepcopy(self.config)
+            new_config[FTConfig.STARTING_MODEL] = (
+                get_model_config_from_checkpoint(
+                    self.config[FTConfig.STARTING_CHECKPOINT]
+                )
+            )
+            standalone_config[FTConfig.FT_CONFIG_NAME] = new_config
+        return standalone_config
+
+    @property
+    def mode(self) -> FineTuneMode:
+        return self._mode
+
+    @property
+    def head_config(self) -> dict:
+        return copy.deepcopy(self.config[FTConfig.HEADS])
+
+
+@registry.register_model(FTHYDRA_NAME)
+class FineTuneHydra(nn.Module, HydraInterface):
+    def __init__(self, finetune_config: dict):
+        super().__init__()
+        ft_config = FTConfig(finetune_config)
+        logging.info(f"Initializing FineTuneHydra model in {ft_config.mode} mode")
+        hydra_model: HydraInterface = ft_config.load_model()
+        self.backbone: BackboneInterface = hydra_model.get_backbone()
+
+        if ft_config.mode == FineTuneMode.DATA_ONLY:
+            # in this mode, we just use the model as is and train on it with new data
+            self.output_heads: dict[str, HeadInterface] = hydra_model.get_heads()
+        elif ft_config.mode == FineTuneMode.RETAIN_BACKBONE_ONLY:
+            # in this mode, we keep the backbone but attach new output heads specified in head config
+            self.output_heads: dict[str, HeadInterface] = {}
+            heads_config = ft_config.head_config
+            head_names_sorted = sorted(heads_config.keys())
+            for head_name in head_names_sorted:
+                head_config = heads_config[head_name]
+                if "module" not in head_config:
+                    raise ValueError(
+                        f"{head_name} head does not specify module to use for the head"
+                    )
+
+                module_name = head_config.pop("module")
+                self.output_heads[head_name] = registry.get_model_class(module_name)(
+                    self.backbone,
+                    **head_config,
+                )
+                num_params = sum(
+                    p.numel() for p in self.output_heads[head_name].parameters()
+                )
+                logging.info(
+                    f"Attaching new output head: {module_name} with {num_params} params"
+                )
+        self.output_heads = torch.nn.ModuleDict(self.output_heads)
+
+
+    def forward(self, data: Batch):
+        emb = self.backbone(data)
+        out = {}
for k in self.output_heads: + out.update(self.output_heads[k](data, emb)) + return out + + def get_backbone(self) -> BackboneInterface: + return self.backbone + + def get_heads(self) -> dict[str, HeadInterface]: + return self.output_heads diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 40c7e65de..850787902 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -36,10 +36,12 @@ get_commit_hash, get_loss_module, load_state_dict, + match_state_dict, save_checkpoint, update_config, ) from fairchem.core.datasets.base_dataset import create_dataset +from fairchem.core.models.finetune_hydra import FineTuneHydra, FTConfig from fairchem.core.modules.evaluator import Evaluator from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage from fairchem.core.modules.loss import DDPLoss @@ -112,8 +114,7 @@ def __init__( self.config = { "task": task, "trainer": name, - "model": aii(model.pop("name"), str), - "model_attributes": model, + "model": model, "outputs": outputs, "optim": optimizer, "loss_functions": loss_functions, @@ -292,9 +293,7 @@ def get_dataloader(self, dataset, sampler) -> DataLoader: ) def load_datasets(self) -> None: - self.ocp_collater = OCPCollater( - self.config["model_attributes"].get("otf_graph", False) - ) + self.ocp_collater = OCPCollater(self.config["model"].get("otf_graph", False)) self.train_loader = None self.val_loader = None self.test_loader = None @@ -506,16 +505,20 @@ def load_task(self): def load_model(self) -> None: # Build model if distutils.is_master(): - logging.info(f"Loading model: {self.config['model']}") + logging.info(f"Loading model: {self.config['model']['name']}") - self.model = registry.get_model_class(self.config["model"])( - **self.config["model_attributes"], + model_config_copy = copy.deepcopy(self.config["model"]) + model_name = model_config_copy.pop("name") + self.model = registry.get_model_class(model_name)( + **model_config_copy, ).to(self.device) + num_params = sum(p.numel() for p in self.model.parameters()) + if distutils.is_master(): logging.info( f"Loaded {self.model.__class__.__name__} with " - f"{self.model.num_params} parameters." + f"{num_params} parameters." ) if self.logger is not None: @@ -525,11 +528,12 @@ def load_model(self) -> None: self.logger.watch( self.model, log_freq=int(self.config["logger"]["watch"]) ) - self.logger.log_summary({"num_params": self.model.num_params}) + self.logger.log_summary({"num_params": num_params}) if distutils.initialized() and not self.config["noddp"]: self.model = DistributedDataParallel( - self.model, device_ids=None if self.cpu else [self.device] + self.model, + device_ids=None if self.cpu else [self.device], ) @property @@ -556,28 +560,8 @@ def load_checkpoint( self.best_val_metric = checkpoint.get("best_val_metric", None) self.primary_metric = checkpoint.get("primary_metric", None) - # Match the "module." count in the keys of model and checkpoint state_dict - # DataParallel model has 1 "module.", DistributedDataParallel has 2 "module." - # Not using either of the above two would have no "module." - - ckpt_key_count = next(iter(checkpoint["state_dict"])).count("module") - mod_key_count = next(iter(self.model.state_dict())).count("module") - key_count_diff = mod_key_count - ckpt_key_count - - if key_count_diff > 0: - new_dict = { - key_count_diff * "module." 
+ k: v - for k, v in checkpoint["state_dict"].items() - } - elif key_count_diff < 0: - new_dict = { - k[len("module.") * abs(key_count_diff) :]: v - for k, v in checkpoint["state_dict"].items() - } - else: - new_dict = checkpoint["state_dict"] - - strict = self.config["task"].get("strict_load", True) + new_dict = match_state_dict(self.model.state_dict(), checkpoint["state_dict"]) + strict = self.config.get("task", {}).get("strict_load", True) load_state_dict(self.model, new_dict, strict=strict) if "optimizer" in checkpoint: @@ -718,6 +702,13 @@ def save( training_state: bool = True, ) -> str | None: if not self.is_debug and distutils.is_master(): + # if we are using a FineTune-able model, then we need to modify the config to remove + # the original starting checkpoint so it can be loaded standalone, can move this to save function + if isinstance(self.model, FineTuneHydra): + self.config["model"] = FTConfig( + self.config["model"][FTConfig.FT_CONFIG_NAME] + ).get_standalone_config() + state = { "state_dict": self.model.state_dict(), "normalizers": { diff --git a/tests/core/datasets/test_create_dataset.py b/tests/core/datasets/test_create_dataset.py index d90271c53..1dc17bcb3 100644 --- a/tests/core/datasets/test_create_dataset.py +++ b/tests/core/datasets/test_create_dataset.py @@ -36,7 +36,7 @@ def get_dataloader(self, *args, **kwargs): return None config = { - "model_attributes": {}, + "model": {}, "optim": {"batch_size": 0}, "dataset": { "format": "ase_db", diff --git a/tests/core/e2e/test_e2e_commons.py b/tests/core/e2e/test_e2e_commons.py new file mode 100644 index 000000000..a171a9689 --- /dev/null +++ b/tests/core/e2e/test_e2e_commons.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import collections.abc +import glob +import os +from pathlib import Path + +import yaml +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + +from fairchem.core._cli import Runner +from fairchem.core.common.flags import flags +from fairchem.core.common.test_utils import ( + PGConfig, + init_env_rank_and_launch_test, + spawn_multi_process, +) +from fairchem.core.common.utils import build_config + + +def oc20_lmdb_train_and_val_from_paths( + train_src, val_src, test_src=None, otf_norms=False +): + datasets = {} + if train_src is not None: + datasets["train"] = { + "src": train_src, + "format": "lmdb", + "key_mapping": {"y": "energy", "force": "forces"}, + } + if otf_norms is True: + datasets["train"].update( + { + "transforms": { + "element_references": { + "fit": { + "targets": ["energy"], + "batch_size": 4, + "num_batches": 10, + "driver": "gelsd", + } + }, + "normalizer": { + "fit": { + "targets": {"energy": None, "forces": {"mean": 0.0}}, + "batch_size": 4, + "num_batches": 10, + } + }, + } + } + ) + else: + datasets["train"].update( + { + "transforms": { + "normalizer": { + "energy": { + "mean": -0.7554450631141663, + "stdev": 2.887317180633545, + }, + "forces": {"mean": 0.0, "stdev": 2.887317180633545}, + } + } + } + ) + if val_src is not None: + datasets["val"] = {"src": val_src, "format": "lmdb"} + if test_src is not None: + datasets["test"] = {"src": test_src, "format": "lmdb"} + return datasets + + +def get_tensorboard_log_files(logdir): + return glob.glob(f"{logdir}/tensorboard/*/events.out*") + + +def get_tensorboard_log_values(logdir): + tf_event_files = get_tensorboard_log_files(logdir) + assert len(tf_event_files) == 1 + tf_event_file = tf_event_files[0] + acc = EventAccumulator(tf_event_file) + acc.Reload() + return acc + + +def 
merge_dictionary(d, u): + for k, v in u.items(): + if isinstance(v, collections.abc.Mapping): + d[k] = merge_dictionary(d.get(k, {}), v) + else: + d[k] = v + return d + +def _run_main( + rundir, + input_yaml, + update_dict_with=None, + update_run_args_with=None, + save_checkpoint_to=None, + save_predictions_to=None, + world_size=0, +): + config_yaml = Path(rundir) / "train_and_val_on_val.yml" + + with open(input_yaml) as yaml_file: + yaml_config = yaml.safe_load(yaml_file) + if update_dict_with is not None: + yaml_config = merge_dictionary(yaml_config, update_dict_with) + yaml_config["backend"] = "gloo" + with open(str(config_yaml), "w") as yaml_file: + yaml.dump(yaml_config, yaml_file) + run_args = { + "run_dir": rundir, + "logdir": f"{rundir}/logs", + "config_yml": config_yaml, + } + if update_run_args_with is not None: + run_args.update(update_run_args_with) + + # run + parser = flags.get_parser() + args, override_args = parser.parse_known_args( + ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu"] + ) + for arg_name, arg_value in run_args.items(): + setattr(args, arg_name, arg_value) + config = build_config(args, override_args) + + if world_size > 0: + pg_config = PGConfig( + backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False + ) + spawn_multi_process( + pg_config, + Runner(distributed=True), + init_env_rank_and_launch_test, + config, + ) + else: + Runner()(config) + + if save_checkpoint_to is not None: + checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") + assert len(checkpoints) == 1 + os.rename(checkpoints[0], save_checkpoint_to) + if save_predictions_to is not None: + predictions_filenames = glob.glob(f"{rundir}/results/*/s2ef_predictions.npz") + assert len(predictions_filenames) == 1 + os.rename(predictions_filenames[0], save_predictions_to) + return get_tensorboard_log_values( + f"{rundir}/logs", + ) diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py new file mode 100644 index 000000000..91f2abd49 --- /dev/null +++ b/tests/core/e2e/test_e2e_finetune_hydra.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import os +import tempfile +from pathlib import Path + +import pytest +import torch +from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths + +from fairchem.core.models.finetune_hydra import FTHYDRA_NAME, FineTuneMode, FTConfig + + +@pytest.fixture() +def tutorial_val_src(tutorial_dataset_path): + return tutorial_dataset_path / "s2ef/val_20" + + +def make_checkpoint(tempdir: str, data_source: Path, seed: int) -> str: + # first train a tiny eqv2 model to get a checkpoint + eqv2_yml = Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml") + ck_path = os.path.join(tempdir, "checkpoint.pt") + _run_main( + tempdir, + eqv2_yml, + update_dict_with={ + "optim": { + "max_epochs": 1, + "eval_every": 8, + "batch_size": 1, + "num_workers": 0, + }, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(data_source), + val_src=str(data_source), + test_src=str(data_source), + otf_norms=False, + ), + }, + update_run_args_with={"seed": seed}, + save_checkpoint_to=ck_path, + world_size=0, + ) + assert os.path.isfile(ck_path) + return ck_path + + +def run_main_with_ft_hydra(tempdir: str, + yaml: str, + data_src: str, + run_args: dict, + ft_config: str, + output_checkpoint: str): + _run_main( + tempdir, + yaml, + update_dict_with={ + "optim": { + "max_epochs": 1, + "eval_every": 8, + "batch_size": 1, + "num_workers": 0, + "lr_initial": 0.0 # 
don't learn anything + }, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(data_src), + val_src=str(data_src), + test_src=str(data_src), + otf_norms=False, + ), + "model": { + "name": FTHYDRA_NAME, + FTConfig.FT_CONFIG_NAME: ft_config, + } + }, + update_run_args_with=run_args, + save_checkpoint_to=output_checkpoint, + world_size=0, + ) + + +def test_finetune_hydra_retain_backbone(tutorial_val_src): + with tempfile.TemporaryDirectory() as orig_ckpt_dir: + starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0) + old_state_dict = torch.load(starting_ckpt)["state_dict"] + # now finetune a the model with the checkpoint from the first job + with tempfile.TemporaryDirectory() as ft_temp_dir: + ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") + ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") + ft_config = { + "mode": FineTuneMode.RETAIN_BACKBONE_ONLY.name, + "starting_checkpoint": starting_ckpt, + "heads": { + "energy": { + "module": "equiformer_v2_energy_head" + }, + "forces": { + "module": "equiformer_v2_force_head" + } + } + } + run_main_with_ft_hydra(tempdir = ft_temp_dir, + yaml = ft_yml, + data_src = tutorial_val_src, + run_args = {"seed": 1000}, + ft_config = ft_config, + output_checkpoint = ck_ft_path) + assert os.path.isfile(ck_ft_path) + ft_ckpt = torch.load(ck_ft_path) + assert "config" in ft_ckpt + assert ft_ckpt["config"]["model"]["name"] == FTHYDRA_NAME + # check that the backbone weights are the same, and other weights are not the same + new_state_dict = ft_ckpt["state_dict"] + for key in new_state_dict: + if key.startswith("backbone"): + # backbone should be identical + assert torch.allclose(new_state_dict[key], old_state_dict[key]) + elif key.startswith("output_heads") and key.endswith("weight"): + # heads weight should be different because the seeds are different + assert not torch.allclose(new_state_dict[key], old_state_dict[key]) + + +def test_finetune_hydra_data_only(tutorial_val_src): + with tempfile.TemporaryDirectory() as orig_ckpt_dir: + starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0) + old_state_dict = torch.load(starting_ckpt)["state_dict"] + # now finetune a the model with the checkpoint from the first job + with tempfile.TemporaryDirectory() as ft_temp_dir: + ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") + ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") + ft_config = { + "mode": FineTuneMode.DATA_ONLY.name, + "starting_checkpoint": starting_ckpt, + } + run_main_with_ft_hydra(tempdir = ft_temp_dir, + yaml = ft_yml, + data_src = tutorial_val_src, + run_args = {"seed": 1000}, + ft_config = ft_config, + output_checkpoint = ck_ft_path) + assert os.path.isfile(ck_ft_path) + ft_ckpt = torch.load(ck_ft_path) + assert "config" in ft_ckpt + config_model = ft_ckpt["config"]["model"] + assert config_model["name"] == FTHYDRA_NAME + # check that the entire model weights are the same + new_state_dict = ft_ckpt["state_dict"] + assert len(new_state_dict) == len(old_state_dict) + for key in new_state_dict: + assert torch.allclose(new_state_dict[key], old_state_dict[key]) + # check the new checkpoint contains a hydra model + assert FTConfig.STARTING_MODEL in config_model[FTConfig.FT_CONFIG_NAME] + + +def test_finetune_from_finetunehydra(tutorial_val_src): + with tempfile.TemporaryDirectory() as orig_ckpt_dir: + starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0) + # now finetune a the model with the checkpoint from the first job + with 
tempfile.TemporaryDirectory() as finetune_run1_dir: + ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") + ck_ft_path = os.path.join(finetune_run1_dir, "checkpoint_ft.pt") + ft_config_1 = { + "mode": FineTuneMode.DATA_ONLY.name, + "starting_checkpoint": starting_ckpt, + } + run_main_with_ft_hydra(tempdir = finetune_run1_dir, + yaml = ft_yml, + data_src = tutorial_val_src, + run_args = {"seed": 1000}, + ft_config = ft_config_1, + output_checkpoint = ck_ft_path) + assert os.path.isfile(ck_ft_path) + + # now that we have a second checkpoint, try finetuning again from this checkpoint + ######################################################################################## + with tempfile.TemporaryDirectory() as finetune_run2_dir: + ck_ft2_path = os.path.join(finetune_run2_dir, "checkpoint_ft.pt") + ft_config_2 = { + "mode": FineTuneMode.DATA_ONLY.name, + "starting_checkpoint": ck_ft_path, + } + run_main_with_ft_hydra(tempdir = finetune_run2_dir, + yaml = ft_yml, + data_src = tutorial_val_src, + run_args = {"seed": 1000}, + ft_config = ft_config_2, + output_checkpoint = ck_ft2_path) + ft_ckpt2 = torch.load(ck_ft2_path) + assert "config" in ft_ckpt2 + config_model = ft_ckpt2["config"]["model"] + assert config_model["name"] == FTHYDRA_NAME + old_state_dict = torch.load(ck_ft_path)["state_dict"] + new_state_dict = ft_ckpt2["state_dict"] + # the state dicts should still be identical because we made the LR = 0.0 + for key in new_state_dict: + assert torch.allclose(new_state_dict[key], old_state_dict[key]) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 6c773d32e..695fb537d 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections.abc import glob import os import tempfile @@ -9,17 +8,9 @@ import numpy as np import numpy.testing as npt import pytest -import yaml -from tensorboard.backend.event_processing.event_accumulator import EventAccumulator +from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths -from fairchem.core._cli import Runner -from fairchem.core.common.flags import flags -from fairchem.core.common.test_utils import ( - PGConfig, - init_env_rank_and_launch_test, - spawn_multi_process, -) -from fairchem.core.common.utils import build_config, setup_logging +from fairchem.core.common.utils import setup_logging from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes setup_logging() @@ -67,142 +58,6 @@ def tutorial_val_src(tutorial_dataset_path): return tutorial_dataset_path / "s2ef/val_20" -def oc20_lmdb_train_and_val_from_paths( - train_src, val_src, test_src=None, otf_norms=False -): - datasets = {} - if train_src is not None: - datasets["train"] = { - "src": train_src, - "format": "lmdb", - "key_mapping": {"y": "energy", "force": "forces"}, - } - if otf_norms is True: - datasets["train"].update( - { - "transforms": { - "element_references": { - "fit": { - "targets": ["energy"], - "batch_size": 4, - "num_batches": 10, - "driver": "gelsd", - } - }, - "normalizer": { - "fit": { - "targets": {"energy": None, "forces": {"mean": 0.0}}, - "batch_size": 4, - "num_batches": 10, - } - }, - } - } - ) - else: - datasets["train"].update( - { - "transforms": { - "normalizer": { - "energy": { - "mean": -0.7554450631141663, - "stdev": 2.887317180633545, - }, - "forces": {"mean": 0.0, "stdev": 2.887317180633545}, - } - } - } - ) - if val_src is not None: - datasets["val"] = {"src": val_src, "format": "lmdb"} - if 
test_src is not None: - datasets["test"] = {"src": test_src, "format": "lmdb"} - return datasets - - -def get_tensorboard_log_files(logdir): - return glob.glob(f"{logdir}/tensorboard/*/events.out*") - - -def get_tensorboard_log_values(logdir): - tf_event_files = get_tensorboard_log_files(logdir) - assert len(tf_event_files) == 1 - tf_event_file = tf_event_files[0] - acc = EventAccumulator(tf_event_file) - acc.Reload() - return acc - - -def merge_dictionary(d, u): - for k, v in u.items(): - if isinstance(v, collections.abc.Mapping): - d[k] = merge_dictionary(d.get(k, {}), v) - else: - d[k] = v - return d - - -def _run_main( - rundir, - input_yaml, - update_dict_with=None, - update_run_args_with=None, - save_checkpoint_to=None, - save_predictions_to=None, - world_size=0, -): - config_yaml = Path(rundir) / "train_and_val_on_val.yml" - - with open(input_yaml) as yaml_file: - yaml_config = yaml.safe_load(yaml_file) - if update_dict_with is not None: - yaml_config = merge_dictionary(yaml_config, update_dict_with) - yaml_config["backend"] = "gloo" - with open(str(config_yaml), "w") as yaml_file: - yaml.dump(yaml_config, yaml_file) - run_args = { - "run_dir": rundir, - "logdir": f"{rundir}/logs", - "config_yml": config_yaml, - } - if update_run_args_with is not None: - run_args.update(update_run_args_with) - - # run - parser = flags.get_parser() - args, override_args = parser.parse_known_args( - ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu"] - ) - for arg_name, arg_value in run_args.items(): - setattr(args, arg_name, arg_value) - config = build_config(args, override_args) - - if world_size > 0: - pg_config = PGConfig( - backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False - ) - spawn_multi_process( - pg_config, - Runner(distributed=True), - init_env_rank_and_launch_test, - config, - ) - else: - Runner()(config) - - if save_checkpoint_to is not None: - checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") - assert len(checkpoints) == 1 - os.rename(checkpoints[0], save_checkpoint_to) - if save_predictions_to is not None: - predictions_filenames = glob.glob(f"{rundir}/results/*/s2ef_predictions.npz") - assert len(predictions_filenames) == 1 - os.rename(predictions_filenames[0], save_predictions_to) - return get_tensorboard_log_values( - f"{rundir}/logs", - ) - - """ These tests are intended to be as quick as possible and test only that the network is runnable and outputs training+validation to tensorboard output These should catch errors such as shape mismatches or otherways to code wise break a network diff --git a/tests/core/models/test_configs/test_finetune_hydra.yml b/tests/core/models/test_configs/test_finetune_hydra.yml new file mode 100644 index 000000000..a5f1dc51b --- /dev/null +++ b/tests/core/models/test_configs/test_finetune_hydra.yml @@ -0,0 +1,55 @@ +trainer: forces + +outputs: + energy: + shape: 1 + level: system + forces: + irrep_dim: 1 + level: atom + train_on_free_atoms: True + eval_on_free_atoms: True + +loss_functions: + - energy: + fn: mae + coefficient: 2 + - forces: + fn: l2mae + coefficient: 100 + +evaluation_metrics: + metrics: + energy: + - mae + forces: + - mae + - cosine_similarity + - magnitude_error + misc: + - energy_forces_within_threshold + primary_metric: forces_mae + +logger: + name: tensorboard + +model: + name: finetune_hydra + finetune_config: {} + + +optim: + batch_size: 5 + eval_batch_size: 2 + num_workers: 0 + lr_initial: 0.0025 + optimizer: AdamW + optimizer_params: {"amsgrad": True,weight_decay: 0.0} + 
eval_every: 190 + max_epochs: 50 + force_coefficient: 20 + scheduler: "Null" + energy_coefficient: 1 + clip_grad_norm: 20 + loss_energy: mae + loss_force: l2mae
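
Reviewer note: the snippet below is a minimal sketch (not part of the diff) of how the new finetune_hydra model can be built through the registry, mirroring what the e2e tests above exercise via the trainer. The checkpoint path is hypothetical; any checkpoint saved from a hydra model run would be used in its place.

    from fairchem.core.common.registry import registry
    from fairchem.core.models.finetune_hydra import FTHYDRA_NAME, FTConfig, FineTuneMode

    # Hypothetical checkpoint path; substitute a real checkpoint produced by a hydra model run.
    ft_config = {
        FTConfig.MODE: FineTuneMode.DATA_ONLY.name,  # RETAIN_BACKBONE_ONLY additionally requires FTConfig.HEADS
        FTConfig.STARTING_CHECKPOINT: "checkpoint.pt",
    }

    # FineTuneHydra reloads the backbone (and, in DATA_ONLY mode, the heads) from the checkpoint,
    # so finetuning continues from the original weights on new data.
    model = registry.get_model_class(FTHYDRA_NAME)(finetune_config=ft_config)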