Skip to content

Commit

Permalink
[LLM] fix pdc download ckpt (#9690)
Browse files: browse the repository at this point in the history
  • Loading branch information
SylarTiaNII authored Dec 24, 2024
1 parent 7ab494a commit f08082e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
9 changes: 9 additions & 0 deletions paddlenlp/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.io import IterableDataset
Expand Down Expand Up @@ -1091,7 +1092,15 @@ def download_recovery_ckpt_from_pdc(recovery_checkpoint_path, timeout):
except Exception as e:
raise RuntimeError(f"{PDC_DOWNLOAD_ERROR}; Failed to parse checkpoint path, details: {e}")
start_time = time.time()
# TODO(@gexiao): temporary workaround for environment variable conflicts.
original_trainer_id = os.getenv("PADDLE_TRAINER_ID")
original_trainers_num = os.getenv("PADDLE_TRAINERS_NUM")
cards_per_node = int(os.getenv("PADDLE_LOCAL_SIZE", "8"))
os.environ["PADDLE_TRAINER_ID"] = str(dist.get_rank() // cards_per_node)
os.environ["PADDLE_TRAINERS_NUM"] = str(dist.get_world_size() // cards_per_node)
result = pdc_tool.pdc_download_checkpoint(download_step, timeout)
os.environ["PADDLE_TRAINER_ID"] = original_trainer_id
os.environ["PADDLE_TRAINERS_NUM"] = original_trainers_num
end_time = time.time()
if result == PDCErrorCode.Success:
logger.info(f"Successfully downloaded checkpoint from PDC, total time cost: {end_time - start_time} seconds.")
Expand Down
4 changes: 3 additions & 1 deletion paddlenlp/utils/pdc_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,9 @@ def _exec_cmd(self, cmd_args: List[str]) -> (str, PDCErrorCode):
error_code = PDCErrorCode.Success
try:
result = subprocess.run(cmd_args, capture_output=True, text=True)
if result.returncode != 0:
if result.returncode == 0:
logger.info(f"exec cmd {cmd_args} successfully, result: {result.stdout}; {result.stderr}")
else:
logger.error(f"exec cmd {cmd_args} failed, exit code: {result.returncode}, err: {result.stderr}")
# TODO(@zezhao): add more error code
error_code = PDCErrorCode.CommandFail
Expand Down

0 comments on commit f08082e

Please sign in to comment.