Skip to content

Commit

Permalink
[LLM] enhance paddle recall for flash checkpoint (#9684)
Browse files Browse the repository at this point in the history
  • Loading branch information
SylarTiaNII authored Dec 24, 2024
1 parent b823d70 commit 7ab494a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 3 additions & 3 deletions paddlenlp/trainer/utils/flash_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
TRAINER_STATE_NAME,
TRAINING_ARGS_NAME,
)
from paddlenlp.utils.fault_tolerance import FC_DUMP_ERROR
from paddlenlp.utils.fault_tolerance import FC_DUMP_ERROR, PC_DUMP_ERROR
from paddlenlp.utils.log import logger


Expand Down Expand Up @@ -300,7 +300,7 @@ def report_error_worker(self):
for worker in self.workers:
if worker.status.value == FCWorkerStatus.ERROR.value:
logger.error(f"[FC manager] Worker{worker.worker_id} encountered error.")
raise RuntimeError(f"{FC_DUMP_ERROR}")
raise RuntimeError(f"{PC_DUMP_ERROR}")

def flash_checkpoint_pipeline_hook(self, hook_id):
if self.current_worker is None:
Expand Down Expand Up @@ -450,7 +450,7 @@ def process_dump_task(self):
self.process_dump_task_impl(self.flash_save_dir)
logger.info(f"[FC worker{self.worker_id}] Dumping to flash device done: {self.flash_save_dir}")
except Exception as e:
logger.error(f"[FC worker{self.worker_id}] Failed to dump to flash device: {e}")
logger.error(f"{FC_DUMP_ERROR} [FC worker{self.worker_id}] Failed to dump to flash device: {e}")
if self.persistent_save_dir:
try:
self.process_dump_task_impl(self.persistent_save_dir)
Expand Down
5 changes: 4 additions & 1 deletion paddlenlp/utils/fault_tolerance.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@
LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"

PDC_DOWNLOAD_ERROR = "PaddleRecall error(105): PDCDownloadError"
FC_DUMP_ERROR = "PaddleRecall error(106): FCDumpError"
# only warn msg
FC_DUMP_ERROR = "PaddleRecall error(106): FlashCheckpointDumpError"
# fatal error, must be fixed by babysitters
PC_DUMP_ERROR = "PaddleRecall error(107): PersistentCheckpointDumpError"


def is_ft_env():
Expand Down

0 comments on commit 7ab494a

Please sign in to comment.