From cd5468ca66577fd70a4800355bdf3b5aca8ecc99 Mon Sep 17 00:00:00 2001 From: wawltor Date: Wed, 11 Dec 2024 22:28:08 +0800 Subject: [PATCH] fix the d2s bug in qwen2 modeling (#9603) * fix the d2s bug in qwen2 modeling * update the code for the predictor --- llm/predict/predictor.py | 5 ++++- paddlenlp/transformers/qwen2/modeling.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py index aaf206e13cc3..0d358398956c 100644 --- a/llm/predict/predictor.py +++ b/llm/predict/predictor.py @@ -27,7 +27,10 @@ from paddle.base.framework import in_cinn_mode, in_pir_executor_mode, use_pir_api from paddle.distributed import fleet -from paddlenlp.experimental.transformers import InferenceWithReferenceProposer +try: + from paddlenlp.experimental.transformers import InferenceWithReferenceProposer +except: + pass from paddlenlp.generation import GenerationConfig, TextIteratorStreamer from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM from paddlenlp.taskflow.utils import static_mode_guard diff --git a/paddlenlp/transformers/qwen2/modeling.py b/paddlenlp/transformers/qwen2/modeling.py index 9b1ab534cc0e..35b3ea91f2b5 100644 --- a/paddlenlp/transformers/qwen2/modeling.py +++ b/paddlenlp/transformers/qwen2/modeling.py @@ -1026,7 +1026,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") expanded_attn_mask = paddle.where(expanded_attn_mask, x, y) else: - expanded_attn_mask = paddle.where(expanded_attn_mask.to("bool"), 0.0, paddle.finfo(dtype).min).astype( + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype( dtype ) return expanded_attn_mask