-
Notifications
You must be signed in to change notification settings - Fork 258
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add wandb logger init to hydra runners #894
base: main
Are you sure you want to change the base?
Changes from all commits
2a231ef
568a8e2
6c8ecd0
ffdea38
0c3fb34
d534bda
d56ba1e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
from typing import TYPE_CHECKING | ||
|
||
import hydra | ||
from omegaconf import OmegaConf | ||
|
||
if TYPE_CHECKING: | ||
import argparse | ||
|
@@ -32,29 +33,48 @@ | |
|
||
|
||
class Submitit(Checkpointable): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just out of curiosity, since the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Runner is not a Checkpointable? |
||
def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None: | ||
def __call__(self, dict_config: DictConfig) -> None: | ||
self.config = dict_config | ||
self.cli_args = cli_args | ||
# TODO: setup_imports is not needed if we stop instantiating models with Registry. | ||
setup_imports() | ||
setup_env_vars() | ||
try: | ||
distutils.setup(map_cli_args_to_dist_config(cli_args)) | ||
runner: Runner = hydra.utils.instantiate(dict_config.runner) | ||
runner.load_state() | ||
runner.run() | ||
finally: | ||
distutils.cleanup() | ||
|
||
def checkpoint(self, *args, **kwargs): | ||
distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args)) | ||
self._init_logger() | ||
runner: Runner = hydra.utils.instantiate(dict_config.runner) | ||
runner.load_state() | ||
runner.run() | ||
distutils.cleanup() | ||
|
||
def _init_logger(self) -> None: | ||
# optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger | ||
# don't start logger if in debug mode | ||
if ( | ||
"logger" in self.config | ||
and distutils.is_master() | ||
and not self.config.cli_args.debug | ||
): | ||
# get a partial function from the config and instantiate wandb with it | ||
logger_initializer = hydra.utils.instantiate(self.config.logger) | ||
simple_config = OmegaConf.to_container( | ||
self.config, resolve=True, throw_on_missing=True | ||
) | ||
logger_initializer( | ||
config=simple_config, | ||
run_id=self.config.cli_args.timestamp_id, | ||
run_name=self.config.cli_args.identifier, | ||
log_dir=self.config.cli_args.logdir, | ||
) | ||
|
||
def checkpoint(self, *args, **kwargs) -> DelayedSubmission: | ||
# TODO: this is yet to be tested properly | ||
logging.info("Submitit checkpointing callback is triggered") | ||
new_runner = Runner() | ||
new_runner.save_state() | ||
logging.info("Submitit checkpointing callback is completed") | ||
return DelayedSubmission(new_runner, self.config) | ||
|
||
|
||
def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict: | ||
def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict: | ||
return { | ||
"world_size": cli_args.num_nodes * cli_args.num_gpus, | ||
"distributed_backend": "gloo" if cli_args.cpu else "nccl", | ||
|
@@ -76,8 +96,8 @@ def get_hydra_config_from_yaml( | |
return hydra.compose(config_name=config_name, overrides=overrides_args) | ||
|
||
|
||
def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace): | ||
Submitit()(config, cli_args) | ||
def runner_wrapper(config: DictConfig): | ||
Submitit()(config) | ||
|
||
|
||
# this is meant as a future replacement for the main entrypoint | ||
|
@@ -91,6 +111,11 @@ def main( | |
cfg = get_hydra_config_from_yaml(args.config_yml, override_args) | ||
timestamp_id = get_timestamp_uid() | ||
log_dir = os.path.join(args.run_dir, timestamp_id, "logs") | ||
# override timestamp id and logdir | ||
args.timestamp_id = timestamp_id | ||
args.logdir = log_dir | ||
os.makedirs(log_dir) | ||
OmegaConf.update(cfg, "cli_args", vars(args), force_add=True) | ||
if args.submit: # Run on cluster | ||
executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3) | ||
executor.update_parameters( | ||
|
@@ -105,7 +130,7 @@ def main( | |
slurm_qos=args.slurm_qos, | ||
slurm_account=args.slurm_account, | ||
) | ||
job = executor.submit(runner_wrapper, cfg, args) | ||
job = executor.submit(runner_wrapper, cfg) | ||
logger.info( | ||
f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}" | ||
) | ||
|
@@ -119,8 +144,8 @@ def main( | |
rdzv_backend="c10d", | ||
max_restarts=0, | ||
) | ||
elastic_launch(launch_config, runner_wrapper)(cfg, args) | ||
elastic_launch(launch_config, runner_wrapper)(cfg) | ||
else: | ||
logger.info("Running in local mode without elastic launch") | ||
distutils.setup_env_local() | ||
runner_wrapper(cfg, args) | ||
runner_wrapper(cfg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is currently a dependency; should we add it to pyproject.toml?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's installed transitively through hydra.