Skip to content

Commit

Permalink
Merge branch 'main' into stress-relaxations
Browse files Browse the repository at this point in the history
  • Loading branch information
lbluque authored Aug 22, 2024
2 parents 4a51006 + 94e4a7f commit 2e118ea
Show file tree
Hide file tree
Showing 17 changed files with 1,255 additions and 556 deletions.
3 changes: 2 additions & 1 deletion src/fairchem/core/common/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from __future__ import annotations

import argparse
import os
from pathlib import Path


Expand Down Expand Up @@ -48,7 +49,7 @@ def add_core_args(self) -> None:
)
self.parser.add_argument(
"--run-dir",
default="./",
default=os.path.abspath("./"),
type=str,
help="Directory to store checkpoint/log/result directory",
)
Expand Down
149 changes: 106 additions & 43 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def save_checkpoint(
return filename


multitask_required_keys = {
"tasks",
"datasets",
"combined_dataset",
"model",
"optim",
}


class Complete:
def __call__(self, data):
device = data.edge_index.device
Expand Down Expand Up @@ -393,48 +402,83 @@ def create_dict_from_args(args: list, sep: str = "."):
return return_dict


def load_config(path: str, previous_includes: list | None = None):
if previous_includes is None:
previous_includes = []
# given a filename and set of paths , return the full file path
def find_relative_file_in_paths(filename, include_paths):
if os.path.exists(filename):
return filename
for path in include_paths:
include_filename = os.path.join(path, filename)
if os.path.exists(include_filename):
return include_filename
raise ValueError(f"Cannot find include YML {filename}")


def load_config(
path: str,
files_previously_included: list | None = None,
include_paths: list | None = None,
):
"""
Load a given config with any defined imports
When imports are present this is a recursive function called on imports.
To prevent any cyclic imports we keep track of already imported yml files
using files_previously_included
"""
if include_paths is None:
include_paths = []
if files_previously_included is None:
files_previously_included = []
path = Path(path)
if path in previous_includes:
if path in files_previously_included:
raise ValueError(
f"Cyclic config include detected. {path} included in sequence {previous_includes}."
f"Cyclic config include detected. {path} included in sequence {files_previously_included}."
)
previous_includes = [*previous_includes, path]
files_previously_included = [*files_previously_included, path]

with open(path) as fp:
direct_config = yaml.load(fp, Loader=UniqueKeyLoader)
current_config = yaml.load(fp, Loader=UniqueKeyLoader)

# Load config from included files.
includes = direct_config.pop("includes") if "includes" in direct_config else []
if not isinstance(includes, list):
raise AttributeError(f"Includes must be a list, '{type(includes)}' provided")
includes_listed_in_config = (
current_config.pop("includes") if "includes" in current_config else []
)
if not isinstance(includes_listed_in_config, list):
raise AttributeError(
f"Includes must be a list, '{type(includes_listed_in_config)}' provided"
)

config = {}
config_from_includes = {}
duplicates_warning = []
duplicates_error = []

for include in includes:
for include in includes_listed_in_config:
include_filename = find_relative_file_in_paths(
include, [os.path.dirname(path), *include_paths]
)
include_config, inc_dup_warning, inc_dup_error = load_config(
include, previous_includes
include_filename, files_previously_included
)
duplicates_warning += inc_dup_warning
duplicates_error += inc_dup_error

# Duplicates between includes causes an error
config, merge_dup_error = merge_dicts(config, include_config)
config_from_includes, merge_dup_error = merge_dicts(
config_from_includes, include_config
)
duplicates_error += merge_dup_error

# Duplicates between included and main file causes warnings
config, merge_dup_warning = merge_dicts(config, direct_config)
config_from_includes, merge_dup_warning = merge_dicts(
config_from_includes, current_config
)
duplicates_warning += merge_dup_warning
return config_from_includes, duplicates_warning, duplicates_error

return config, duplicates_warning, duplicates_error


def build_config(args, args_override):
config, duplicates_warning, duplicates_error = load_config(args.config_yml)
def build_config(args, args_override, include_paths=None):
config, duplicates_warning, duplicates_error = load_config(
args.config_yml, include_paths=include_paths
)
if len(duplicates_warning) > 0:
logging.warning(
f"Overwritten config parameters from included configs "
Expand Down Expand Up @@ -999,34 +1043,53 @@ class _TrainingContext:
task_name = "s2ef"
elif trainer_name in ["energy", "equiformerv2_energy"]:
task_name = "is2re"
elif "multitask" in trainer_name:
task_name = "multitask"
else:
task_name = "ocp"

trainer_cls = registry.get_trainer_class(trainer_name)
assert trainer_cls is not None, "Trainer not found"
trainer = trainer_cls(
task=config.get("task", {}),
model=config["model"],
outputs=config.get("outputs", {}),
dataset=config["dataset"],
optimizer=config["optim"],
loss_functions=config.get("loss_functions", {}),
evaluation_metrics=config.get("evaluation_metrics", {}),
identifier=config["identifier"],
timestamp_id=config.get("timestamp_id", None),
run_dir=config.get("run_dir", "./"),
is_debug=config.get("is_debug", False),
print_every=config.get("print_every", 10),
seed=config.get("seed", 0),
logger=config.get("logger", "wandb"),
local_rank=config["local_rank"],
amp=config.get("amp", False),
cpu=config.get("cpu", False),
slurm=config.get("slurm", {}),
noddp=config.get("noddp", False),
name=task_name,
gp_gpus=config.get("gp_gpus"),
)

trainer_config = {
"model": config["model"],
"optimizer": config["optim"],
"identifier": config["identifier"],
"timestamp_id": config.get("timestamp_id", None),
"run_dir": config.get("run_dir", "./"),
"is_debug": config.get("is_debug", False),
"print_every": config.get("print_every", 10),
"seed": config.get("seed", 0),
"logger": config.get("logger", "wandb"),
"local_rank": config["local_rank"],
"amp": config.get("amp", False),
"cpu": config.get("cpu", False),
"slurm": config.get("slurm", {}),
"noddp": config.get("noddp", False),
"name": task_name,
"gp_gpus": config.get("gp_gpus"),
}

if task_name == "multitask":
trainer_config.update(
{
"tasks": config.get("tasks", {}),
"dataset_configs": config["datasets"],
"combined_dataset_config": config.get("combined_dataset", {}),
"evaluations": config.get("evaluations", {}),
}
)
else:
trainer_config.update(
{
"task": config.get("task", {}),
"outputs": config.get("outputs", {}),
"dataset": config["dataset"],
"loss_functions": config.get("loss_functions", {}),
"evaluation_metrics": config.get("evaluation_metrics", {}),
}
)
trainer = trainer_cls(**trainer_config)

task_cls = registry.get_task_class(config["mode"])
assert task_cls is not None, "Task not found"
Expand Down
21 changes: 17 additions & 4 deletions src/fairchem/core/datasets/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,32 @@
from torch_geometric.data import Data


def rename_data_object_keys(data_object: Data, key_mapping: dict[str, str]) -> Data:
def rename_data_object_keys(
data_object: Data, key_mapping: dict[str, str | list[str]]
) -> Data:
"""Rename data object keys
Args:
data_object: data object
key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key}
new_key can be a list of new keys, for example,
prev_key: energy
new_key: [common_energy, oc20_energy]
This is currently required when we use a single target/label for multiple tasks
"""
for _property in key_mapping:
# catch for test data not containing labels
if _property in data_object:
new_property = key_mapping[_property]
if new_property not in data_object:
list_of_new_keys = key_mapping[_property]
if isinstance(list_of_new_keys, str):
list_of_new_keys = [list_of_new_keys]
for new_property in list_of_new_keys:
if new_property == _property:
continue
assert new_property not in data_object
data_object[new_property] = data_object[_property]
if _property not in list_of_new_keys:
del data_object[_property]

return data_object
2 changes: 1 addition & 1 deletion src/fairchem/core/models/equiformer_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations

from .equiformer_v2 import EquiformerV2
from .equiformer_v2_deprecated import EquiformerV2

__all__ = ["EquiformerV2"]
Loading

0 comments on commit 2e118ea

Please sign in to comment.