Finetune Hydra #797

Merged: 19 commits, Aug 14, 2024

29 changes: 29 additions & 0 deletions src/fairchem/core/common/utils.py
@@ -1102,6 +1102,35 @@ def _report_incompat_keys(
return missing_keys, unexpected_keys


def match_state_dict(
model_state_dict: Mapping[str, torch.Tensor],
checkpoint_state_dict: Mapping[str, torch.Tensor],
) -> dict:
# match the model's state dict with the checkpoint state_dict and return a new dict
# whose keys are compatible with the model

# Match the "module." count in the keys of model and checkpoint state_dict
# DataParallel model has 1 "module.", DistributedDataParallel has 2 "module."
# A model wrapped in neither has no "module." prefix

ckpt_key_count = next(iter(checkpoint_state_dict)).count("module")
mod_key_count = next(iter(model_state_dict)).count("module")
key_count_diff = mod_key_count - ckpt_key_count

if key_count_diff > 0:
new_dict = {
key_count_diff * "module." + k: v for k, v in checkpoint_state_dict.items()
}
elif key_count_diff < 0:
new_dict = {
k[len("module.") * abs(key_count_diff) :]: v
for k, v in checkpoint_state_dict.items()
}
else:
new_dict = checkpoint_state_dict
return new_dict


def load_state_dict(
module: nn.Module,
state_dict: Mapping[str, torch.Tensor],
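
The match_state_dict helper above reconciles the "module." prefixes that DataParallel/DistributedDataParallel wrappers add to parameter names. A minimal sketch of how it behaves, assuming the fairchem package is importable and using made-up key names:

import torch

from fairchem.core.common.utils import match_state_dict

# Checkpoint saved from an unwrapped model: keys carry no "module." prefix.
checkpoint_state_dict = {"backbone.weight": torch.zeros(2)}

# A wrapped model exposes the same parameter under a "module." prefix.
model_state_dict = {"module.backbone.weight": torch.zeros(2)}

# The checkpoint keys are re-prefixed so they match the model's keys.
matched = match_state_dict(model_state_dict, checkpoint_state_dict)
assert list(matched) == ["module.backbone.weight"]
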
25 changes: 23 additions & 2 deletions src/fairchem/core/models/base.py
@@ -7,8 +7,9 @@

from __future__ import annotations

import copy
import logging
from abc import ABCMeta, abstractmethod
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

@@ -227,8 +228,19 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
return


class HydraInterface(ABC):
# a hydra has a backbone and heads
@abstractmethod
def get_backbone(self) -> BackboneInterface:
raise NotImplementedError

@abstractmethod
def get_heads(self) -> dict[str, HeadInterface]:
raise NotImplementedError


@registry.register_model("hydra")
class HydraModel(nn.Module, GraphModelMixin):
class HydraModel(nn.Module, GraphModelMixin, HydraInterface):
def __init__(
self,
backbone: dict,
@@ -237,6 +249,9 @@ def __init__(
):
super().__init__()
self.otf_graph = otf_graph
# make a copy so we don't modify the original config
backbone = copy.deepcopy(backbone)
heads = copy.deepcopy(heads)

backbone_model_name = backbone.pop("model")
self.backbone: BackboneInterface = registry.get_model_class(
@@ -272,3 +287,9 @@ def forward(self, data: Batch):
out.update(self.output_heads[k](data, emb))

return out

def get_backbone(self) -> BackboneInterface:
return self.backbone

def get_heads(self) -> dict[str, HeadInterface]:
return self.output_heads
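
HydraInterface gives downstream code a uniform way to reach the backbone and heads without knowing the concrete model class. A small sketch of a hypothetical consumer (summarize_hydra is not part of this PR), assuming the backbone and heads are nn.Modules:

from fairchem.core.models.base import HydraInterface

def summarize_hydra(model: HydraInterface) -> None:
    # Works for HydraModel, FineTuneHydra, or any future HydraInterface implementation.
    backbone = model.get_backbone()
    print(f"backbone: {type(backbone).__name__}")
    for name, head in model.get_heads().items():
        n_params = sum(p.numel() for p in head.parameters())
        print(f"head '{name}': {type(head).__name__} ({n_params} params)")
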
177 changes: 177 additions & 0 deletions src/fairchem/core/models/finetune_hydra.py
@@ -0,0 +1,177 @@
from __future__ import annotations

import copy
import errno
import logging
import os
from enum import Enum
from typing import TYPE_CHECKING

import torch
from torch import nn

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import load_state_dict, match_state_dict
from fairchem.core.models.base import BackboneInterface, HeadInterface, HydraInterface

if TYPE_CHECKING:
from torch_geometric.data import Batch

FTHYDRA_NAME = "finetune_hydra"

class FineTuneMode(Enum):
# in DATA_ONLY, we load the entire model and only finetune on new data
DATA_ONLY = 1
# in this mode, we only load the Backbone and feed the output of the backbone
# to new heads that are specified
RETAIN_BACKBONE_ONLY = 2


def get_model_config_from_checkpoint(checkpoint_path: str) -> dict:
if not os.path.isfile(checkpoint_path):
raise FileNotFoundError(
errno.ENOENT, "Checkpoint file not found", checkpoint_path
)
checkpoint = torch.load(checkpoint_path)
return checkpoint["config"]["model"]


def load_hydra_model(checkpoint_path: str) -> HydraInterface:
if not os.path.isfile(checkpoint_path):
raise FileNotFoundError(
errno.ENOENT, "Checkpoint file not found", checkpoint_path
)
logging.info(f"Loading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path)
config = checkpoint["config"]["model"]
name = config.pop("name")
hydra_model = registry.get_model_class(name)(**config)
assert isinstance(
hydra_model, HydraInterface
), "Can only load models with the HydraInterface"
matched_dict = match_state_dict(hydra_model.state_dict(), checkpoint["state_dict"])
load_state_dict(hydra_model, matched_dict, strict=True)
return hydra_model


class FTConfig:
FT_CONFIG_NAME = "finetune_config"
STARTING_CHECKPOINT = "starting_checkpoint"
STARTING_MODEL = "starting_model"
MODE = "mode"
HEADS = "heads"

def __init__(self, config: dict):
self.config = config
self._mode = FineTuneMode[self.config[FTConfig.MODE]]
assert (
(FTConfig.STARTING_CHECKPOINT in self.config)
or (FTConfig.STARTING_MODEL in self.config)
), "Either a starting checkpoint or a starting model must be provided!"
assert FTConfig.MODE in self.config
if self._mode == FineTuneMode.RETAIN_BACKBONE_ONLY:
# in this mode, we keep the backbone but attach new output heads specified in head config
assert (
FTConfig.HEADS in self.config
), "heads cannot be empty when using RETAIN_BACKBONE_ONLY mode!"

def load_model(self) -> nn.Module:
# if provided a hydra config to start, build from the starting hydra model.
# this assumes the weights will be loaded later from the state_dict in the
# checkpoint.pt file, so no actual weights are loaded here
if FTConfig.STARTING_MODEL in self.config:
# register model from hydra_config
config_copy = copy.deepcopy(self.config[FTConfig.STARTING_MODEL])
name = config_copy.pop("name")
hydra_model = registry.get_model_class(name)(**config_copy)
# if provided a checkpoint to start then load the model and weights from the given checkpoint
# this happens at the beginning of a finetuning run
elif FTConfig.STARTING_CHECKPOINT in self.config:
hydra_model: HydraInterface = load_hydra_model(
self.config[FTConfig.STARTING_CHECKPOINT]
)
assert isinstance(hydra_model, HydraInterface)

num_params = sum(p.numel() for p in hydra_model.parameters())
logging.info(f"Loaded original hydra model with {num_params} params")
return hydra_model

def get_standalone_config(self) -> dict:
# replace a config with a checkpoint with one that has the model config only
# this is required for standalone prediction (so we don't need to ship the original checkpoint),
# multi-round finetuning, and better robustness
standalone_config = {
"name": FTHYDRA_NAME,
FTConfig.FT_CONFIG_NAME: self.config,
}
if FTConfig.STARTING_CHECKPOINT in self.config:
# modify the config to store the original model config inside model attrs
# so we don't need the original checkpoint again when loading from this checkpoint
new_config = copy.deepcopy(self.config)
new_config[FTConfig.STARTING_MODEL] = (
get_model_config_from_checkpoint(
self.config[FTConfig.STARTING_CHECKPOINT]
)
)
standalone_config[FTConfig.FT_CONFIG_NAME] = new_config
return standalone_config

@property
def mode(self) -> FineTuneMode:
return self._mode

@property
def head_config(self) -> dict:
return copy.deepcopy(self.config[FTConfig.HEADS])


@registry.register_model(FTHYDRA_NAME)
class FineTuneHydra(nn.Module, HydraInterface):
def __init__(self, finetune_config: dict):
super().__init__()
ft_config = FTConfig(finetune_config)
logging.info(f"Initializing FineTuneHydra model in {ft_config.mode} mode")
hydra_model: HydraInterface = ft_config.load_model()
self.backbone: BackboneInterface = hydra_model.get_backbone()

if ft_config.mode == FineTuneMode.DATA_ONLY:
# in this mode, we just use the model as is and train on it with new data
self.output_heads: dict[str, HeadInterface] = hydra_model.get_heads()
elif ft_config.mode == FineTuneMode.RETAIN_BACKBONE_ONLY:
# in this mode, we keep the backbone but attach new output heads specified in head config
self.output_heads: dict[str, HeadInterface] = {}
heads_config = ft_config.head_config
head_names_sorted = sorted(heads_config.keys())
for head_name in head_names_sorted:
head_config = heads_config[head_name]
if "module" not in head_config:
raise ValueError(
f"{head_name} head does not specify module to use for the head"
)

module_name = head_config.pop("module")
self.output_heads[head_name] = registry.get_model_class(module_name)(
self.backbone,
**head_config,
)
num_params = sum(
p.numel() for p in self.output_heads[head_name].parameters()
)
logging.info(
f"Attaching new output head: {module_name} with {num_params} params"
)
self.output_heads = torch.nn.ModuleDict(self.output_heads)


def forward(self, data: Batch):
emb = self.backbone(data)
out = {}
for k in self.output_heads:
out.update(self.output_heads[k](data, emb))
return out

def get_backbone(self) -> BackboneInterface:
return self.backbone

def get_heads(self) -> dict[str, HeadInterface]:
return self.output_heads
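
For reference, a hedged sketch of what the model section of a finetuning config could look like for each FineTuneMode; the checkpoint path, head name, and head module below are placeholders, not names from this PR:

# DATA_ONLY: reuse the full original hydra model (backbone and heads) and train on new data.
data_only_model_config = {
    "name": "finetune_hydra",
    "finetune_config": {
        "mode": "DATA_ONLY",
        "starting_checkpoint": "/path/to/original_hydra_checkpoint.pt",  # placeholder path
    },
}

# RETAIN_BACKBONE_ONLY: keep only the backbone and attach freshly initialized heads;
# a non-empty "heads" section is required, and each entry must name its registered "module".
backbone_only_model_config = {
    "name": "finetune_hydra",
    "finetune_config": {
        "mode": "RETAIN_BACKBONE_ONLY",
        "starting_checkpoint": "/path/to/original_hydra_checkpoint.pt",  # placeholder path
        "heads": {
            "energy": {"module": "my_energy_head"},  # placeholder head module name
        },
    },
}
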
57 changes: 24 additions & 33 deletions src/fairchem/core/trainers/base_trainer.py
@@ -36,10 +36,12 @@
get_commit_hash,
get_loss_module,
load_state_dict,
match_state_dict,
save_checkpoint,
update_config,
)
from fairchem.core.datasets.base_dataset import create_dataset
from fairchem.core.models.finetune_hydra import FineTuneHydra, FTConfig
from fairchem.core.modules.evaluator import Evaluator
from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage
from fairchem.core.modules.loss import DDPLoss
@@ -112,8 +114,7 @@ def __init__(
self.config = {
"task": task,
"trainer": name,
"model": aii(model.pop("name"), str),
"model_attributes": model,
"model": model,
"outputs": outputs,
"optim": optimizer,
"loss_functions": loss_functions,
@@ -292,9 +293,7 @@ def get_dataloader(self, dataset, sampler) -> DataLoader:
)

def load_datasets(self) -> None:
self.ocp_collater = OCPCollater(
self.config["model_attributes"].get("otf_graph", False)
)
self.ocp_collater = OCPCollater(self.config["model"].get("otf_graph", False))
self.train_loader = None
self.val_loader = None
self.test_loader = None
@@ -506,16 +505,20 @@ def load_task(self):
def load_model(self) -> None:
# Build model
if distutils.is_master():
logging.info(f"Loading model: {self.config['model']}")
logging.info(f"Loading model: {self.config['model']['name']}")

self.model = registry.get_model_class(self.config["model"])(
**self.config["model_attributes"],
model_config_copy = copy.deepcopy(self.config["model"])
model_name = model_config_copy.pop("name")
self.model = registry.get_model_class(model_name)(
**model_config_copy,
).to(self.device)

num_params = sum(p.numel() for p in self.model.parameters())

if distutils.is_master():
logging.info(
f"Loaded {self.model.__class__.__name__} with "
f"{self.model.num_params} parameters."
f"{num_params} parameters."
)

if self.logger is not None:
@@ -525,11 +528,12 @@ def load_model(self) -> None:
self.logger.watch(
self.model, log_freq=int(self.config["logger"]["watch"])
)
self.logger.log_summary({"num_params": self.model.num_params})
self.logger.log_summary({"num_params": num_params})

if distutils.initialized() and not self.config["noddp"]:
self.model = DistributedDataParallel(
self.model, device_ids=None if self.cpu else [self.device]
self.model,
device_ids=None if self.cpu else [self.device],
)

@property
@@ -556,28 +560,8 @@ def load_checkpoint(
self.best_val_metric = checkpoint.get("best_val_metric", None)
self.primary_metric = checkpoint.get("primary_metric", None)

# Match the "module." count in the keys of model and checkpoint state_dict
# DataParallel model has 1 "module.", DistributedDataParallel has 2 "module."
# Not using either of the above two would have no "module."

ckpt_key_count = next(iter(checkpoint["state_dict"])).count("module")
mod_key_count = next(iter(self.model.state_dict())).count("module")
key_count_diff = mod_key_count - ckpt_key_count

if key_count_diff > 0:
new_dict = {
key_count_diff * "module." + k: v
for k, v in checkpoint["state_dict"].items()
}
elif key_count_diff < 0:
new_dict = {
k[len("module.") * abs(key_count_diff) :]: v
for k, v in checkpoint["state_dict"].items()
}
else:
new_dict = checkpoint["state_dict"]

strict = self.config["task"].get("strict_load", True)
new_dict = match_state_dict(self.model.state_dict(), checkpoint["state_dict"])
strict = self.config.get("task", {}).get("strict_load", True)
load_state_dict(self.model, new_dict, strict=strict)

if "optimizer" in checkpoint:
@@ -718,6 +702,13 @@ def save(
training_state: bool = True,
) -> str | None:
if not self.is_debug and distutils.is_master():
# if we are using a FineTune-able model, then we need to modify the config so the
# saved checkpoint can be loaded standalone, without the original starting checkpoint
if isinstance(self.model, FineTuneHydra):
self.config["model"] = FTConfig(
self.config["model"][FTConfig.FT_CONFIG_NAME]
).get_standalone_config()

state = {
"state_dict": self.model.state_dict(),
"normalizers": {
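
The save() change above swaps in FTConfig.get_standalone_config() so a finetuned checkpoint no longer depends on the original checkpoint file. A hedged illustration of the rewrite (all values are placeholders):

# Config as passed in at the start of a finetuning run.
before = {
    "name": "finetune_hydra",
    "finetune_config": {
        "mode": "DATA_ONLY",
        "starting_checkpoint": "/path/to/original_hydra_checkpoint.pt",  # placeholder
    },
}

# Config written into the new checkpoint: the original model config is inlined
# under "starting_model", so load_model() can rebuild the hydra model without
# reading the original checkpoint again.
after = {
    "name": "finetune_hydra",
    "finetune_config": {
        "mode": "DATA_ONLY",
        "starting_checkpoint": "/path/to/original_hydra_checkpoint.pt",  # placeholder
        "starting_model": {
            "name": "hydra",
            # backbone/heads sections copied from checkpoint["config"]["model"]
        },
    },
}
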
2 changes: 1 addition & 1 deletion tests/core/datasets/test_create_dataset.py
@@ -36,7 +36,7 @@ def get_dataloader(self, *args, **kwargs):
return None

config = {
"model_attributes": {},
"model": {},
"optim": {"batch_size": 0},
"dataset": {
"format": "ase_db",