diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 57c655736f25..d11ed0495ce9 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -2530,7 +2530,10 @@ def _save_checkpoint(self, model, metrics=None):
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             os.makedirs(signal_dir, exist_ok=True)
             paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
-            if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+            if (
+                "skip_save_model_weight" not in self.args.unified_checkpoint_config
+                or "remove_master_weight" not in self.args.unified_checkpoint_config
+            ):
                 paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
         if self.args.should_save or self.args.use_expert_parallel:
             if not self.args.use_hybrid_parallel:
@@ -2567,7 +2570,10 @@ def _save_checkpoint(self, model, metrics=None):
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             os.makedirs(signal_dir, exist_ok=True)
             paddle.save(global_rank, os.path.join(signal_dir, f".optimizer_weight.done.{global_rank}"))
-            if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+            if (
+                "skip_save_model_weight" not in self.args.unified_checkpoint_config
+                or "remove_master_weight" not in self.args.unified_checkpoint_config
+            ):
                 paddle.save(global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}"))
 
         self.runtime_timer.stop()
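
For reference, a minimal standalone sketch of the new predicate (the helper name is hypothetical and not part of the patch): with the `or`, the `.master_weight.done.{global_rank}` signal file is skipped only when both `skip_save_model_weight` and `remove_master_weight` appear in `unified_checkpoint_config`, whereas the old check skipped it whenever `skip_save_model_weight` alone was present.

```python
# Sketch only: isolates the condition from the diff above. `unified_checkpoint_config`
# is treated as a collection of option strings, as the `in` checks suggest.


def should_write_master_weight_signal(unified_checkpoint_config):
    # After the patch, the signal is suppressed only when BOTH options are set.
    return (
        "skip_save_model_weight" not in unified_checkpoint_config
        or "remove_master_weight" not in unified_checkpoint_config
    )


if __name__ == "__main__":
    # Old condition would have returned False here; new condition still writes
    # the signal because "remove_master_weight" is absent.
    print(should_write_master_weight_signal({"skip_save_model_weight"}))  # True
    # Both options present: signal file is not written.
    print(should_write_master_weight_signal({"skip_save_model_weight", "remove_master_weight"}))  # False
```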