From a53b294b47884172c90ea07aa137c078a3c289ae Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Thu, 8 Aug 2024 10:19:27 +0000 Subject: [PATCH] diloco ckpt check fix --- open_diloco/train_fsdp.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py index 33b70e6..d04e67e 100644 --- a/open_diloco/train_fsdp.py +++ b/open_diloco/train_fsdp.py @@ -70,8 +70,13 @@ def log(message): logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}") -def check_checkpoint_path_access(checkpoint_path: str, rank: int): - dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt") +def check_checkpoint_path_access(checkpoint_path: str, rank: int, world_rank_hv: int | None = None): + if world_rank_hv: + dummy_file_path = os.path.join( + checkpoint_path, get_diloco_rank_dir_name(world_rank_hv), f"dummy_file_{rank}.txt" + ) + else: + dummy_file_path = os.path.join(checkpoint_path, f"dummy_file_{rank}.txt") with fsspec.open(dummy_file_path, "w") as f: f.write("This is a dummy file for testing access.") gfs = GenericFileSystem() @@ -221,7 +226,7 @@ def train(config: Config): log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=False) if local_rank == 0: - check_checkpoint_path_access(config.checkpoint_path, rank) + check_checkpoint_path_access(config.checkpoint_path, rank, config.hv.world_rank if config.hv else None) # DataLoader preparation tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)