Add wandb logger init to hydra runners #894

Open · wants to merge 7 commits into base: main

Changes from all commits
59 changes: 42 additions & 17 deletions src/fairchem/core/_cli_hydra.py
@@ -12,6 +12,7 @@
 from typing import TYPE_CHECKING

 import hydra
+from omegaconf import OmegaConf
Collaborator:
I don't think this is currently a dependency; should we add this to pyproject.toml?

Collaborator (Author):
It's installed through hydra.
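
For reference, this is easy to confirm locally. A minimal sketch, assuming the hydra in use is the hydra-core distribution (the snippet is illustrative and not part of this PR):

# List hydra-core's declared requirements and keep the omegaconf entry;
# if omegaconf shows up here, it is installed transitively whenever
# hydra-core itself is declared as a dependency.
from importlib.metadata import requires

print([r for r in requires("hydra-core") if "omegaconf" in r.lower()])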


 if TYPE_CHECKING:
     import argparse
@@ -32,29 +33,48 @@


 class Submitit(Checkpointable):
Collaborator:
Just out of curiosity, since the Runner class is already a Checkpointable, do we need both the Runner and Submitit to inherit from Checkpointable?

Collaborator (Author):
Runner is not a Checkpointable?
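
For context on the exchange above, here is a minimal sketch of the submitit pattern this class relies on, assuming submitit.helpers provides Checkpointable and DelayedSubmission; the Trainer name is hypothetical and not fairchem code:

from submitit.helpers import Checkpointable, DelayedSubmission

class Trainer(Checkpointable):  # hypothetical stand-in for the Submitit class in this diff
    def __call__(self, config) -> None:
        ...  # long-running work that may be preempted or time out

    def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
        # called by submitit on SLURM preemption/timeout; the callable wrapped in
        # the returned DelayedSubmission is requeued with these arguments
        return DelayedSubmission(Trainer(), *args, **kwargs)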

-    def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None:
+    def __call__(self, dict_config: DictConfig) -> None:
         self.config = dict_config
-        self.cli_args = cli_args
         # TODO: setup_imports is not needed if we stop instantiating models with Registry.
         setup_imports()
         setup_env_vars()
-        try:
-            distutils.setup(map_cli_args_to_dist_config(cli_args))
-            runner: Runner = hydra.utils.instantiate(dict_config.runner)
-            runner.load_state()
-            runner.run()
-        finally:
-            distutils.cleanup()
-
-    def checkpoint(self, *args, **kwargs):
+        distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
+        self._init_logger()
+        runner: Runner = hydra.utils.instantiate(dict_config.runner)
+        runner.load_state()
+        runner.run()
+        distutils.cleanup()
+
+    def _init_logger(self) -> None:
+        # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
+        # don't start logger if in debug mode
+        if (
+            "logger" in self.config
+            and distutils.is_master()
+            and not self.config.cli_args.debug
+        ):
+            # get a partial function from the config and instantiate wandb with it
+            logger_initializer = hydra.utils.instantiate(self.config.logger)
+            simple_config = OmegaConf.to_container(
+                self.config, resolve=True, throw_on_missing=True
+            )
+            logger_initializer(
+                config=simple_config,
+                run_id=self.config.cli_args.timestamp_id,
+                run_name=self.config.cli_args.identifier,
+                log_dir=self.config.cli_args.logdir,
+            )
+
+    def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
         # TODO: this is yet to be tested properly
         logging.info("Submitit checkpointing callback is triggered")
         new_runner = Runner()
         new_runner.save_state()
         logging.info("Submitit checkpointing callback is completed")
         return DelayedSubmission(new_runner, self.config)


-def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict:
+def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict:
     return {
         "world_size": cli_args.num_nodes * cli_args.num_gpus,
         "distributed_backend": "gloo" if cli_args.cpu else "nccl",
@@ -76,8 +96,8 @@ def get_hydra_config_from_yaml(
     return hydra.compose(config_name=config_name, overrides=overrides_args)


-def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace):
-    Submitit()(config, cli_args)
+def runner_wrapper(config: DictConfig):
+    Submitit()(config)


 # this is meant as a future replacement for the main entrypoint
@@ -91,6 +111,11 @@ def main(
     cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
     timestamp_id = get_timestamp_uid()
     log_dir = os.path.join(args.run_dir, timestamp_id, "logs")
+    # override timestamp id and logdir
+    args.timestamp_id = timestamp_id
+    args.logdir = log_dir
+    os.makedirs(log_dir)
+    OmegaConf.update(cfg, "cli_args", vars(args), force_add=True)
     if args.submit:  # Run on cluster
         executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
         executor.update_parameters(
@@ -105,7 +130,7 @@ def main(
             slurm_qos=args.slurm_qos,
             slurm_account=args.slurm_account,
         )
-        job = executor.submit(runner_wrapper, cfg, args)
+        job = executor.submit(runner_wrapper, cfg)
         logger.info(
             f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}"
         )
@@ -119,8 +144,8 @@ def main(
                 rdzv_backend="c10d",
                 max_restarts=0,
             )
-            elastic_launch(launch_config, runner_wrapper)(cfg, args)
+            elastic_launch(launch_config, runner_wrapper)(cfg)
         else:
             logger.info("Running in local mode without elastic launch")
             distutils.setup_env_local()
-            runner_wrapper(cfg, args)
+            runner_wrapper(cfg)
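
The OmegaConf.update(..., force_add=True) call added in main() above is what makes self.config.cli_args available inside Submitit instead of a separately passed argparse Namespace. A small sketch of that mechanism, with made-up values:

import argparse
from omegaconf import OmegaConf

cfg = OmegaConf.create({"runner": {"_target_": "some.module.Runner"}})  # placeholder runner entry
args = argparse.Namespace(
    timestamp_id="ts-0001", identifier="debug-run", logdir="/tmp/logs",
    debug=False, cpu=True, num_nodes=1, num_gpus=1,
)
# force_add=True allows inserting a key the existing config does not define,
# so downstream code can read cfg.cli_args.* as ordinary config fields
OmegaConf.update(cfg, "cli_args", vars(args), force_add=True)
print(cfg.cli_args.timestamp_id, cfg.cli_args.debug)  # ts-0001 False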