From adf0335e2921ea6fc013429b67605249bab50103 Mon Sep 17 00:00:00 2001 From: shibing624 Date: Tue, 20 Jun 2023 15:24:27 +0800 Subject: [PATCH] fixed load peft weight. https://github.com/shibing624/textgen/issues/47 --- textgen/chatglm/chatglm_model.py | 2 +- textgen/gpt/gpt_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/textgen/chatglm/chatglm_model.py b/textgen/chatglm/chatglm_model.py index e4a07df..5f4dcee 100644 --- a/textgen/chatglm/chatglm_model.py +++ b/textgen/chatglm/chatglm_model.py @@ -327,7 +327,7 @@ def train_model( # The two files above have a different name depending on how they were saved, but are actually the same. if os.path.exists(checkpoint_name): logger.info(f"Restarting from {checkpoint_name}") - adapters_weights = torch.load(checkpoint_name) + adapters_weights = torch.load(checkpoint_name, map_location='cpu') set_peft_model_state_dict(self.model, adapters_weights) else: logger.warning(f"Checkpoint {checkpoint_name} not found") diff --git a/textgen/gpt/gpt_model.py b/textgen/gpt/gpt_model.py index 80f08ad..e573b8d 100644 --- a/textgen/gpt/gpt_model.py +++ b/textgen/gpt/gpt_model.py @@ -356,7 +356,7 @@ def train_model( # The two files above have a different name depending on how they were saved, but are actually the same. if os.path.exists(checkpoint_name): logger.info(f"Restarting from {checkpoint_name}") - adapters_weights = torch.load(checkpoint_name) + adapters_weights = torch.load(checkpoint_name, map_location='cpu') set_peft_model_state_dict(self.model, adapters_weights) else: logger.warning(f"Checkpoint {checkpoint_name} not found")