support qwen dpo fused kernel (#9686)
wtmlon authored Dec 30, 2024
1 parent 79695cc commit 07ad5e6
Showing 2 changed files with 18 additions and 23 deletions.
29 changes: 10 additions & 19 deletions llm/alignment/dpo/dpo_argument.py
@@ -19,9 +19,11 @@
 from paddlenlp.trainer import TrainingArguments
 from paddlenlp.trainer.trainer_utils import IntervalStrategy
 from paddlenlp.trainer.utils.doc import add_start_docstrings
+from paddlenlp.transformers.configuration_utils import llmmetaclass


 @dataclass
+@llmmetaclass
 @add_start_docstrings(TrainingArguments.__doc__)
 class DPOTrainingArguments(TrainingArguments):
     """DPOTrainingArguments"""
@@ -122,30 +124,19 @@ class DPOModelArgument:
     tokenizer_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
     )
-    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
-    recompute_granularity: str = field(
-        default="full",
-        metadata={
-            "help": "The granularity of recompute training can be selected as `full` or `full_attn` or `core_attn`."
-        },
-    )
     flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash mask in flash attention."})
-    virtual_pp_degree: int = field(
-        default=1,
-        metadata={"help": "virtual_pp_degree"},
-    )
-    sequence_parallel: bool = field(
-        default=False,
-        metadata={"help": "whether to use sequence parallel"},
-    )
-    tensor_parallel_output: bool = field(
-        default=True,
-        metadata={"help": "whether to use tensor_parallel_output"},
-    )
     weight_quantize_algo: str = field(
         default=None,
         metadata={"help": "Model weight quantization algorithm including 'nf4'(qlora), 'weight_only_int8'."},
     )
+    fuse_attention_qkv: bool = field(
+        default=None,
+        metadata={"help": "whether to fuse attention qkv"},
+    )
+    fuse_attention_ffn: bool = field(
+        default=None,
+        metadata={"help": "whether to fuse first up and gate proj in mlp block"},
+    )
     # LoRA
     lora_rank: int = field(default=8, metadata={"help": "Lora rank."})
     lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
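Note on the relocation above: the duplicated parallelism/recompute switches are dropped from DPOModelArgument because, once @llmmetaclass is applied, they are expected to be supplied by DPOTrainingArguments instead. A minimal sketch of what that implies for argument parsing follows; the parse_dict call mirrors PaddleNLP's HfArgumentParser-style interface, and the dpo_argument import path and all values are assumptions, not part of this commit.

# Sketch only: checks that the relocated switches now resolve on the training
# arguments. Field names come from the diff above; everything else is assumed.
from paddlenlp.trainer import PdArgumentParser

from dpo_argument import DPOModelArgument, DPOTrainingArguments  # hypothetical local import

parser = PdArgumentParser((DPOModelArgument, DPOTrainingArguments))
model_args, training_args = parser.parse_dict(
    {
        "model_name_or_path": "Qwen/Qwen2-0.5B",  # placeholder model id
        "output_dir": "./checkpoints/dpo",        # required TrainingArguments field
        "use_flash_attention": True,              # previously a DPOModelArgument field
        "recompute_granularity": "full",
        "sequence_parallel": False,
        "tensor_parallel_output": True,
    }
)
assert hasattr(training_args, "use_flash_attention")
assert not hasattr(model_args, "use_flash_attention")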
12 changes: 8 additions & 4 deletions llm/alignment/dpo/run_dpo.py
@@ -122,9 +122,13 @@ def main():
         dtype=dtype,
         tensor_parallel_degree=training_args.tensor_parallel_degree,
         tensor_parallel_rank=training_args.tensor_parallel_rank,
-        recompute_granularity=model_args.recompute_granularity,
-        use_flash_attention=model_args.use_flash_attention,
-        tensor_parallel_output=model_args.tensor_parallel_output,
+        recompute_granularity=training_args.recompute_granularity,
+        use_flash_attention=training_args.use_flash_attention,
+        tensor_parallel_output=training_args.tensor_parallel_output,
+        use_fused_rms_norm=training_args.use_fused_rms_norm,
+        use_fused_rope=training_args.use_fused_rope,
+        use_fused_linear=training_args.use_fused_linear,
+        use_fused_dropout_add=training_args.use_fused_dropout_add,
     )

     if training_args.pipeline_parallel_degree > 1:
@@ -157,7 +161,7 @@ def main():
     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")

-    if model_args.sequence_parallel:
+    if training_args.sequence_parallel:
         register_sequence_parallel_allreduce_hooks(
             model, training_args.gradient_accumulation_steps, model_args.fuse_sequence_parallel_allreduce
         )
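Usage note (not part of the commit): with the extra keyword arguments above, a Qwen DPO run can request the fused kernels through its training arguments, and they are forwarded into the model config. A minimal sketch, assuming the call shown in the first hunk is PaddleNLP's AutoConfig.from_pretrained and that the Qwen config accepts these flags as overrides; the model id and flag values are placeholders.

# Mirrors the keyword arguments added in run_dpo.py; values are illustrative.
from paddlenlp.transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "Qwen/Qwen2-0.5B",               # placeholder model id
    dtype="bfloat16",
    use_flash_attention=True,
    use_fused_rms_norm=True,         # fused RMSNorm kernel
    use_fused_rope=True,             # fused rotary position embedding kernel
    use_fused_linear=True,
    use_fused_dropout_add=True,      # fused dropout + residual add
)
print(config.use_fused_rms_norm)     # the switches are carried on the model config

Whether each switch actually takes effect at runtime typically still depends on the fused operators available in the installed Paddle build; the flags only request the fused paths.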
