diff --git a/src/fairchem/core/models/base.py b/src/fairchem/core/models/base.py index ab3c95afa..32865e0ef 100644 --- a/src/fairchem/core/models/base.py +++ b/src/fairchem/core/models/base.py @@ -242,6 +242,7 @@ def __init__( finetune_config: dict | None = None, otf_graph: bool = True, pass_through_head_outputs: bool = False, + freeze_backbone: bool = False, ): super().__init__() self.device = None @@ -282,6 +283,10 @@ def __init__( "Backbone not specified and not found in the starting checkpoint" ) + if freeze_backbone: + for param in self.backbone.parameters(): + param.requires_grad = False + if heads is not None: heads = copy.deepcopy(heads) # Iterate through outputs_cfg and create heads diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py index 9a36e09ef..4dc2e2efc 100644 --- a/tests/core/e2e/test_e2e_finetune_hydra.py +++ b/tests/core/e2e/test_e2e_finetune_hydra.py @@ -5,10 +5,11 @@ from pathlib import Path import pytest -from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint import torch from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths +from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint + @pytest.fixture() def tutorial_val_src(tutorial_dataset_path): @@ -104,12 +105,122 @@ def verify_release_checkpoint(release_yaml_fn, release_checkpoint_fn, ft_state_d assert os.path.isfile(ck_release_ft_afterload_path) ft_after_state_dict = torch.load(ck_release_ft_afterload_path)["state_dict"] for key in ft_after_state_dict: - if key.startswith("module.backbone"): - assert torch.allclose(ft_after_state_dict[key], ft_state_dict[key]) - elif key.startswith("module.output_heads") and key.endswith("weight"): + if ( + key.startswith("module.backbone") + or key.startswith("module.output_heads") + and key.endswith("weight") + ): assert torch.allclose(ft_after_state_dict[key], ft_state_dict[key]) +def test_finetune_hydra_freeze_backbone(tutorial_val_src): + with tempfile.TemporaryDirectory() as orig_ckpt_dir: + starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0) + old_state_dict = torch.load(starting_ckpt)["state_dict"] + + # Test to make sure without freeze the backbone weights change + with tempfile.TemporaryDirectory() as ft_temp_dir: + ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") + ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") + model_config = { + "name": "hydra", + "finetune_config": {"starting_checkpoint": starting_ckpt}, + "heads": { + "energy": {"module": "equiformer_v2_energy_head"}, + "forces": {"module": "equiformer_v2_force_head"}, + }, + } + + _run_main( + ft_temp_dir, + ft_yml, + update_dict_with={ + "optim": { + "max_epochs": 1, + "eval_every": 8, + "batch_size": 1, + "num_workers": 0, + "lr_initial": 10.0, + }, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + otf_norms=False, + ), + "model": model_config, + }, + update_run_args_with={"seed": 1000}, + save_checkpoint_to=ck_ft_path, + world_size=1, + ) + + assert os.path.isfile(ck_ft_path) + ft_ckpt = torch.load(ck_ft_path) + assert "config" in ft_ckpt + assert ft_ckpt["config"]["model"]["name"] == "hydra" + # check that the backbone weights are different, and other weights are not the same + ft_state_dict = ft_ckpt["state_dict"] + for key in ft_state_dict: + if key.startswith("module.backbone") and ".weight" in key: + # backbone should be different + assert not torch.allclose(ft_state_dict[key], old_state_dict[key]) + elif key.startswith("module.output_heads") and key.endswith("weight"): + # heads weight should be different because the seeds are different + assert not torch.allclose(ft_state_dict[key], old_state_dict[key]) + + # Test to make sure with freeze the backbone weights are unchanged + with tempfile.TemporaryDirectory() as ft_temp_dir: + ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml") + ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt") + model_config = { + "name": "hydra", + "finetune_config": {"starting_checkpoint": starting_ckpt}, + "heads": { + "energy": {"module": "equiformer_v2_energy_head"}, + "forces": {"module": "equiformer_v2_force_head"}, + }, + "freeze_backbone": True, + } + + _run_main( + ft_temp_dir, + ft_yml, + update_dict_with={ + "optim": { + "max_epochs": 1, + "eval_every": 8, + "batch_size": 1, + "num_workers": 0, + }, + "dataset": oc20_lmdb_train_and_val_from_paths( + train_src=str(tutorial_val_src), + val_src=str(tutorial_val_src), + test_src=str(tutorial_val_src), + otf_norms=False, + ), + "model": model_config, + }, + update_run_args_with={"seed": 1000}, + save_checkpoint_to=ck_ft_path, + world_size=1, + ) + + assert os.path.isfile(ck_ft_path) + ft_ckpt = torch.load(ck_ft_path) + assert "config" in ft_ckpt + assert ft_ckpt["config"]["model"]["name"] == "hydra" + # check that the backbone weights are different, and other weights are not the same + ft_state_dict = ft_ckpt["state_dict"] + for key in ft_state_dict: + if key.startswith("module.backbone"): + # backbone should be different + assert torch.allclose(ft_state_dict[key], old_state_dict[key]) + elif key.startswith("module.output_heads") and key.endswith("weight"): + # heads weight should be different because the seeds are different + assert not torch.allclose(ft_state_dict[key], old_state_dict[key]) + + def test_finetune_hydra_retain_backbone(tutorial_val_src): with tempfile.TemporaryDirectory() as orig_ckpt_dir: starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)