Skip to content

Commit

Permalink
support both checkpoint styles
Browse files Browse the repository at this point in the history
  • Loading branch information
rgao committed Dec 10, 2024
1 parent 9ff9a3d commit d583250
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
10 changes: 10 additions & 0 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,3 +1478,13 @@ def get_weight_table(model: torch.nn.Module) -> tuple[list, list]:
row_grad = [None] * len(row_weight)
data.append([param_name] + [params.shape] + row_weight + row_grad) # noqa
return columns, data


def get_checkpoint_format(config: dict) -> str:
    """Return the checkpoint format declared in a (possibly old-style) config.

    Temporary shim for old configs that predate the ``checkpoint_format``
    option: when ``optim.checkpoint_format`` is absent, the legacy ``"pt"``
    format is assumed.

    Args:
        config: Full run configuration; the format is read from
            ``config["optim"]["checkpoint_format"]`` if present.

    Returns:
        ``"pt"`` (single checkpoint file) or ``"dcp"`` (directory-based
        checkpoint — presumably torch distributed checkpoint; see callers).

    Raises:
        ValueError: If the configured format is neither ``"pt"`` nor ``"dcp"``.
    """
    # Named ``fmt`` to avoid shadowing the ``format`` builtin.
    fmt = config.get("optim", {}).get("checkpoint_format", "pt")
    # Raise instead of assert: asserts are stripped under ``python -O``,
    # and this validates user-supplied configuration.
    if fmt not in ("pt", "dcp"):
        raise ValueError(f"checkpoint format can only be pt or dcp, found {fmt}")
    return fmt
15 changes: 12 additions & 3 deletions src/fairchem/core/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import get_checkpoint_format
from fairchem.core.trainers import OCPTrainer


Expand All @@ -21,8 +22,13 @@ def __init__(self, config) -> None:
def setup(self, trainer) -> None:
self.trainer = trainer

# TODO: make checkpoint.pt a constant so we don't pass this string around everywhere
self.chkpt_path = os.path.join(self.trainer.config["cmd"]["checkpoint_dir"])
format = get_checkpoint_format(self.config)
if format == "pt":
self.chkpt_path = os.path.join(
self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
)
else:
self.chkpt_path = self.trainer.config["cmd"]["checkpoint_dir"]

# if the supplied checkpoint exists, then load that, ie: when user specifies the --checkpoint option
# OR if the a job was preempted correctly and the submitit checkpoint function was called
Expand All @@ -36,7 +42,10 @@ def setup(self, trainer) -> None:
# if the supplied checkpoint doesn't exist and there exists a previous checkpoint in the checkpoint path, this
# means that the previous job didn't terminate "nicely" (due to node failures, crashes etc), then attempt
# to load the last found checkpoint
elif len(os.listdir(self.chkpt_path)) > 0:
elif (
os.path.isfile(self.chkpt_path)
or (os.path.isdir(self.chkpt_path) and len(os.listdir(self.chkpt_path))) > 0
):
logging.info(
f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkpoint"
)
Expand Down

0 comments on commit d583250

Please sign in to comment.