
Commit

add script to port old equiv2 checkpoint+yaml to hydra version (#846)
* add script to port old equiv2 checkpoint+yaml to hydra version

* fix up comments

* lint

* move script and add forces to test
misko authored Sep 13, 2024
1 parent 6ded0d3 commit 7d40b20
Showing 3 changed files with 237 additions and 0 deletions.
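
The new script is a thin command-line wrapper around convert_checkpoint_and_config_to_hydra (both files are shown below). A hypothetical invocation might look like the following; the flag names come from the script's argument parser, while the paths are placeholders:

python src/fairchem/core/scripts/eqv2_to_hydra_eqv2.py \
    --eqv2-checkpoint /path/to/eqv2_checkpoint.pt \
    --eqv2-yaml /path/to/eqv2_config.yml \
    --hydra-eqv2-checkpoint /path/to/hydra_checkpoint.pt \
    --hydra-eqv2-yaml /path/to/hydra_config.yml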
88 changes: 88 additions & 0 deletions src/fairchem/core/models/equiformer_v2/eqv2_to_eqv2_hydra.py
@@ -0,0 +1,88 @@
from __future__ import annotations

import logging
import os
from collections import OrderedDict
from copy import deepcopy

import torch
import yaml


def convert_checkpoint_and_config_to_hydra(
yaml_fn, checkpoint_fn, new_yaml_fn, new_checkpoint_fn
):
assert not os.path.exists(new_yaml_fn), "Output yaml cannot already exist!"
assert not os.path.exists(
new_checkpoint_fn
), "Output checkpoint cannot already exist!"

def eqv2_state_dict_to_hydra_state_dict(eqv2_state_dict):
hydra_state_dict = OrderedDict()
for og_key in list(eqv2_state_dict.keys()):
if "force_block" in og_key or "energy_block" in og_key:
key = og_key.replace(
"force_block", "output_heads.forces.force_block"
).replace("energy_block", "output_heads.energy.energy_block")
else:
offset = 0
if og_key[: len("module.")] == "module.":
offset += len("module.")
key = og_key[:offset] + "backbone." + og_key[offset:]
hydra_state_dict[key] = eqv2_state_dict[og_key]
return hydra_state_dict

def convert_configs_to_hydra(yaml_config, checkpoint_config):
new_model_config = {
"name": "hydra",
"backbone": checkpoint_config["model"].copy(),
"heads": {
"energy": {"module": "equiformer_v2_energy_head"},
"forces": {"module": "equiformer_v2_force_head"},
},
}
assert new_model_config["backbone"]["name"] in ["equiformer_v2"]
new_model_config["backbone"].pop("name")
new_model_config["backbone"]["model"] = "equiformer_v2_backbone"

# create a new checkpoint config
new_checkpoint_config = deepcopy(checkpoint_config)
new_checkpoint_config["model"] = new_model_config

# create a new YAML config
new_yaml_config = deepcopy(yaml_config)
new_yaml_config["model"] = new_model_config

for output_key, output_d in new_yaml_config["outputs"].items():
if output_d["level"] == "system":
output_d["property"] = "energy"
elif output_d["level"] == "atom":
output_d["property"] = "forces"
else:
logging.warning(
f"Converting output:{output_key} to new equiv2 hydra config \
failed to find level and could not set property in output correctly"
)

return new_yaml_config, new_checkpoint_config

# load existing from disk
with open(yaml_fn) as yaml_f:
yaml_config = yaml.safe_load(yaml_f)
checkpoint = torch.load(checkpoint_fn, map_location="cpu")

new_checkpoint = checkpoint.copy()
new_yaml_config, new_checkpoint_config = convert_configs_to_hydra(
yaml_config, checkpoint["config"]
)
new_checkpoint["config"] = new_checkpoint_config
new_checkpoint["state_dict"] = eqv2_state_dict_to_hydra_state_dict(
checkpoint["state_dict"]
)
for key in ["ema", "optimizer", "scheduler"]:
new_checkpoint.pop(key)

# write output
torch.save(new_checkpoint, new_checkpoint_fn)
with open(str(new_yaml_fn), "w") as yaml_file:
yaml.dump(new_yaml_config, yaml_file)
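
For orientation, a sketch of the "model" section that convert_configs_to_hydra produces from an EquiformerV2 config (the backbone hyperparameters are elided here; they are copied over verbatim from the original config, minus its "name" key):

new_model_config = {
    "name": "hydra",
    "backbone": {
        "model": "equiformer_v2_backbone",
        # ... every remaining key of the original "model" section, unchanged
    },
    "heads": {
        "energy": {"module": "equiformer_v2_energy_head"},
        "forces": {"module": "equiformer_v2_force_head"},
    },
}

The state dict is remapped to match: parameters of energy_block and force_block move under output_heads.energy / output_heads.forces respectively, and every other key gains a "backbone." segment after any leading "module." prefix (e.g. a hypothetical "module.some_param" would become "module.backbone.some_param").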
36 changes: 36 additions & 0 deletions src/fairchem/core/scripts/eqv2_to_hydra_eqv2.py
@@ -0,0 +1,36 @@
from __future__ import annotations

import argparse

from fairchem.core.models.equiformer_v2.eqv2_to_eqv2_hydra import (
convert_checkpoint_and_config_to_hydra,
)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--eqv2-checkpoint", help="path to eqv2 checkpoint", type=str, required=True
)
parser.add_argument(
"--eqv2-yaml", help="path to eqv2 yaml config", type=str, required=True
)
parser.add_argument(
"--hydra-eqv2-checkpoint",
help="path where to output hydra checkpoint",
type=str,
required=True,
)
parser.add_argument(
"--hydra-eqv2-yaml",
help="path where to output hydra yaml",
type=str,
required=True,
)
args = parser.parse_args()

convert_checkpoint_and_config_to_hydra(
yaml_fn=args.eqv2_yaml,
checkpoint_fn=args.eqv2_checkpoint,
new_yaml_fn=args.hydra_eqv2_yaml,
new_checkpoint_fn=args.hydra_eqv2_checkpoint,
)
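
The same conversion can also be run programmatically; a minimal sketch with placeholder paths (the keyword names come from the function signature in eqv2_to_eqv2_hydra.py):

from fairchem.core.models.equiformer_v2.eqv2_to_eqv2_hydra import (
    convert_checkpoint_and_config_to_hydra,
)

# Placeholder paths; the converter asserts that neither output file exists yet,
# so it will never overwrite anything.
convert_checkpoint_and_config_to_hydra(
    yaml_fn="configs/eqv2.yml",
    checkpoint_fn="checkpoints/eqv2.pt",
    new_yaml_fn="configs/eqv2_hydra.yml",
    new_checkpoint_fn="checkpoints/eqv2_hydra.pt",
)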
113 changes: 113 additions & 0 deletions tests/core/e2e/test_s2ef.py
@@ -14,6 +14,10 @@
update_yaml_with_dict,
)

from fairchem.core.models.equiformer_v2.eqv2_to_eqv2_hydra import (
convert_checkpoint_and_config_to_hydra,
)

from fairchem.core.common.flags import flags
from fairchem.core.common.utils import build_config, setup_logging
from fairchem.core.modules.scaling.fit import compute_scaling_factors
@@ -159,6 +163,115 @@ def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
input_yaml=configs["gemnet_oc"],
)

def test_convert_checkpoint_and_config_to_hydra(self, configs, tutorial_val_src):
with tempfile.TemporaryDirectory() as tempdirname:
# first train a very simple model then checkpoint
train_rundir = Path(tempdirname) / "train"
train_rundir.mkdir()
checkpoint_path = str(train_rundir / "checkpoint.pt")
acc = _run_main(
rundir=str(train_rundir),
input_yaml=configs["equiformer_v2"],
update_dict_with={
"optim": {
"max_epochs": 2,
"eval_every": 8,
"batch_size": 5,
},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
test_src=str(tutorial_val_src),
otf_norms=False,
),
},
save_checkpoint_to=checkpoint_path,
)

# load the checkpoint and predict
predictions_rundir = Path(tempdirname) / "predict"
predictions_rundir.mkdir()
predictions_filename = str(predictions_rundir / "predictions.npz")
_run_main(
rundir=str(predictions_rundir),
input_yaml=configs["equiformer_v2"],
update_dict_with={
"optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
test_src=str(tutorial_val_src),
otf_norms=False,
),
},
update_run_args_with={
"mode": "predict",
"checkpoint": checkpoint_path,
},
save_predictions_to=predictions_filename,
)

# convert the checkpoint to hydra
hydra_yaml = Path(tempdirname) / "hydra.yml"
hydra_checkpoint = Path(tempdirname) / "hydra.pt"

convert_checkpoint_and_config_to_hydra(
yaml_fn=configs["equiformer_v2"],
checkpoint_fn=checkpoint_path,
new_yaml_fn=hydra_yaml,
new_checkpoint_fn=hydra_checkpoint,
)

# load hydra checkpoint and predict
hydra_predictions_rundir = Path(tempdirname) / "hydra_predict"
hydra_predictions_rundir.mkdir()
            hydra_predictions_filename = str(hydra_predictions_rundir / "predictions.npz")
_run_main(
rundir=str(hydra_predictions_rundir),
input_yaml=hydra_yaml,
update_dict_with={
"optim": {"max_epochs": 2, "eval_every": 8, "batch_size": 5},
"dataset": oc20_lmdb_train_and_val_from_paths(
train_src=str(tutorial_val_src),
val_src=str(tutorial_val_src),
test_src=str(tutorial_val_src),
otf_norms=False,
),
},
update_run_args_with={
"mode": "predict",
"checkpoint": hydra_checkpoint,
},
save_predictions_to=hydra_predictions_filename,
)

# verify predictions from eqv2 and hydra eqv2 are same
energy_from_checkpoint = np.load(predictions_filename)["energy"]
energy_from_hydra_checkpoint = np.load(hydra_predictions_filename)["energy"]
npt.assert_allclose(
energy_from_hydra_checkpoint,
energy_from_checkpoint,
rtol=1e-6,
atol=1e-6,
)
forces_from_checkpoint = np.load(predictions_filename)["forces"]
forces_from_hydra_checkpoint = np.load(hydra_predictions_filename)["forces"]
npt.assert_allclose(
forces_from_hydra_checkpoint,
forces_from_checkpoint,
rtol=1e-6,
atol=1e-6,
)

# should not let you overwrite the files if they exist
with pytest.raises(AssertionError):
convert_checkpoint_and_config_to_hydra(
yaml_fn=configs["equiformer_v2"],
checkpoint_fn=checkpoint_path,
new_yaml_fn=hydra_yaml,
new_checkpoint_fn=hydra_checkpoint,
)

# not all models are tested with otf normalization estimation
# only gemnet_oc, escn, equiformer, and their hydra versions
@pytest.mark.parametrize(
