diff --git a/csrc/gpu/append_attention.cu b/csrc/gpu/append_attention.cu
index f80f8cee5d3d..d24a20e48d11 100644
--- a/csrc/gpu/append_attention.cu
+++ b/csrc/gpu/append_attention.cu
@@ -59,10 +59,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
-    const int encoder_block_shape_q,
-    const int decoder_block_shape_q,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool speculate_decoder) {
@@ -76,7 +72,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
   int max_enc_len_this_time_data = max_enc_len_this_time.data<int>()[0];
   int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
   int max_len_kv_data = max_len_kv.data<int>()[0];
-
+  const int encoder_block_shape_q = get_encoder_block_shape_q();
+  const int decoder_block_shape_q = get_decoder_block_shape_q();
   auto main_stream = qkv.stream();
   static cudaEvent_t main_event;
   static cudaEvent_t decoder_event;
@@ -209,8 +206,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         false,
@@ -248,8 +243,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         false,
@@ -292,8 +285,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         false,
@@ -440,8 +431,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         !speculate_decoder,
@@ -479,8 +468,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         !speculate_decoder,
@@ -524,8 +511,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         !speculate_decoder,
@@ -583,10 +568,6 @@ std::vector<paddle::Tensor> AppendAttention(
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
-    const int encoder_block_shape_q,
-    const int decoder_block_shape_q,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool speculate_decoder) {
@@ -648,10 +629,6 @@ std::vector<paddle::Tensor> AppendAttention(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        encoder_block_shape_q,
-        decoder_block_shape_q,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         speculate_decoder);
@@ -698,10 +675,6 @@ std::vector<paddle::Tensor> AppendAttention(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        encoder_block_shape_q,
-        decoder_block_shape_q,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         speculate_decoder);
@@ -749,10 +722,6 @@ std::vector<paddle::Tensor> AppendAttention(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        encoder_block_shape_q,
-        decoder_block_shape_q,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         speculate_decoder);
@@ -798,10 +767,6 @@ std::vector<paddle::Tensor> AppendAttention(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        encoder_block_shape_q,
-        decoder_block_shape_q,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         speculate_decoder);
@@ -903,10 +868,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
-    const int encoder_block_shape_q,
-    const int decoder_block_shape_q,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool speculate_decoder) {
@@ -983,10 +944,6 @@ PD_BUILD_OP(append_attention)
             "quant_max_bound: float",
             "quant_min_bound: float",
             "out_linear_in_scale: float",
-            "encoder_block_shape_q: int",
-            "decoder_block_shape_q: int",
-            "max_partition_size: int",
-            "encoder_max_partition_size: int",
             "speculate_max_draft_token_num: int",
             "causal: bool",
             "speculate_decoder: bool"})
diff --git a/csrc/gpu/append_attn/append_attention_c16_impl.cuh b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
index ed181836d73c..78526f703a3f 100644
--- a/csrc/gpu/append_attn/append_attention_c16_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
@@ -786,8 +786,6 @@ void MultiQueryAppendAttention(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool is_decoder,
     cudaStream_t &stream,
@@ -839,9 +837,9 @@
     int sm_count;
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+    uint32_t chunk_size = get_max_partition_size();
     if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
+      chunk_size = get_encoder_max_partition_size();
     }
     const int num_chunks = div_up(max_dec_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1058,9 +1056,9 @@
     int sm_count;
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+    uint32_t chunk_size = get_max_partition_size();
     if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
+      chunk_size = get_encoder_max_partition_size();
     }
     const int num_chunks = div_up(max_dec_len, chunk_size);
@@ -1301,8 +1299,6 @@ void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
@@ -1363,8 +1359,6 @@
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         is_decoder,
         stream,
diff --git a/csrc/gpu/append_attn/append_attention_c4_impl.cuh b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
index 586bde4dc741..9fd2eaebb841 100644
--- a/csrc/gpu/append_attn/append_attention_c4_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
@@ -973,8 +973,6 @@ void MultiQueryAppendC4Attention(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool is_decoder,
     cudaStream_t &stream,
@@ -1036,9 +1034,9 @@ void MultiQueryAppendC4Attention(
     const float ratio =
        static_cast<float>(num_blocks_need) /
        static_cast<float>(num_blocks_per_wave);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+    uint32_t chunk_size = get_max_partition_size();
     if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
+      chunk_size = get_encoder_max_partition_size();
     }
     const int num_chunks = div_up(max_dec_len, chunk_size);
@@ -1282,9 +1280,9 @@
        static_cast<float>(num_blocks_per_wave);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+    uint32_t chunk_size = get_max_partition_size();
     if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
+      chunk_size = get_encoder_max_partition_size();
     }
     const int num_chunks = div_up(max_dec_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1538,8 +1536,6 @@ void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
@@ -1604,8 +1600,6 @@
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         is_decoder,
         stream,
diff --git a/csrc/gpu/append_attn/append_attention_c8_impl.cuh b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
index d5d1cc38e1b4..efe8f40a17c3 100644
--- a/csrc/gpu/append_attn/append_attention_c8_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
@@ -860,8 +860,6 @@ void MultiQueryAppendC8Attention(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool is_decoder,
     cudaStream_t &stream,
@@ -914,9 +912,9 @@
     const int dev_id = 0;
     int sm_count;
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+    uint32_t chunk_size = get_max_partition_size();
     if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
+      chunk_size = get_encoder_max_partition_size();
     }
     const int num_chunks = div_up(max_dec_len, chunk_size);
     dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1136,9 +1134,9 @@
     const int dev_id = 0;
     int sm_count;
     cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-    uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+    uint32_t chunk_size = get_max_partition_size();
     if (!is_decoder) {
-      chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
+      chunk_size = get_encoder_max_partition_size();
     }
     const int num_chunks = div_up(max_dec_len, chunk_size);
@@ -1377,8 +1375,6 @@ void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
@@ -1441,8 +1437,6 @@
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         is_decoder,
         stream,
diff --git a/csrc/gpu/append_attn/append_attention_kernel.h b/csrc/gpu/append_attn/append_attention_kernel.h
index 59532b2400c5..f740057e9321 100644
--- a/csrc/gpu/append_attn/append_attention_kernel.h
+++ b/csrc/gpu/append_attn/append_attention_kernel.h
@@ -52,8 +52,6 @@ void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
@@ -97,8 +95,6 @@ void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
@@ -142,8 +138,6 @@ void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
@@ -188,8 +182,6 @@ void CascadeAppendAttentionKernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
@@ -223,8 +215,6 @@ void CascadeAppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         is_decoder,
@@ -258,8 +248,6 @@ void CascadeAppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         is_decoder,
@@ -293,8 +281,6 @@ void CascadeAppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
-        encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
         is_decoder,
@@ -307,3 +293,17 @@ void CascadeAppendAttentionKernel(
         "cache_int4_zp]");
   }
 }
+
+inline uint32_t get_max_partition_size() {
+  static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
+  static const uint32_t max_partition_size =
+      max_partition_size_env == nullptr ? 128 : std::stoul(std::string(max_partition_size_env));
+  return max_partition_size;
+}
+
+inline uint32_t get_encoder_max_partition_size() {
+  static const char* encoder_max_partition_size_env = std::getenv("FLAGS_cascade_encoder_attention_max_partition_size");
+  static const uint32_t encoder_max_partition_size =
+      encoder_max_partition_size_env == nullptr ? 32768 : std::stoul(std::string(encoder_max_partition_size_env));
+  return encoder_max_partition_size;
+}
\ No newline at end of file
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu
index 7dafef74ba88..79ba5cd7bc85 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu
@@ -49,8 +49,6 @@ template void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu
index 806eecbb529d..09e149c25233 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu
index c677686d68aa..648d301880b8 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu
index 75c6e80c3056..a3f0c95f02e2 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC4Kernel
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu
index 065834d6d0d8..63b03741b0e7 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu
index 3a2b13a89045..aae73a837de4 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu
@@ -49,8 +49,6 @@ template void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu
index 4f5dedb15dc5..57c5e36fca93 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu
index 606c9128a973..e5d85cad2b5e 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu
@@ -50,8 +50,6 @@ CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu
index efc54738fafc..e115efacf907 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu
index 83728df8d409..cfa10da809da 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu
index 35267a59f55b..842fb6415fca 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu
@@ -48,8 +48,6 @@ template void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
-    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
     const bool is_decoder,
diff --git a/csrc/gpu/helper.h b/csrc/gpu/helper.h
index 4e8aa488141a..7ce33a017ec2 100644
--- a/csrc/gpu/helper.h
+++ b/csrc/gpu/helper.h
@@ -221,3 +221,17 @@ __device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids, int length) {
   }
   return flag;
 }
+
+inline uint32_t get_decoder_block_shape_q() {
+  static const char* decoder_block_shape_q_env = std::getenv("FLAGS_flag_dec_block_shape_q");
+  static const uint32_t decoder_block_shape_q =
+      decoder_block_shape_q_env == nullptr ? 16 : std::stoi(std::string(decoder_block_shape_q_env));
+  return decoder_block_shape_q;
+}
+
+inline uint32_t get_encoder_block_shape_q() {
+  static const char* encoder_block_shape_q_env = std::getenv("FLAGS_flag_block_shape_q");
+  static const uint32_t encoder_block_shape_q =
+      encoder_block_shape_q_env == nullptr ? 64 : std::stoi(std::string(encoder_block_shape_q_env));
+  return encoder_block_shape_q;
+}
\ No newline at end of file
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index cdf5730c7a86..8810c3fe19e4 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -1033,10 +1033,6 @@ def forward(
         kwargs["max_dec_len_this_time"] = max_dec_len_this_time

         if self.config.append_attn:
-            kwargs["encoder_block_shape_q"] = 64
-            kwargs["decoder_block_shape_q"] = 16
-            kwargs["max_partition_size"] = 32768
-            kwargs["encoder_max_partition_size"] = 32768

             from paddlenlp_ops import get_block_shape_and_split_kv_block

@@ -1057,8 +1053,6 @@
                 max_enc_len_this_time,
                 kwargs.get("seq_lens_this_time", None),
                 kwargs.get("cum_offsets", None),
-                kwargs.get("encoder_block_shape_q", 64),
-                kwargs.get("decoder_block_shape_q", 16),
                 self.num_heads // self.kv_num_heads,
                 kwargs.get("block_size", 64),
                 self.config.speculate_config.speculate_max_draft_token_num,
@@ -2197,10 +2191,6 @@ def compute_attn(
                 0.0,
                 0.0,
                 0.0,  # out_linear_in_scale
-                kwargs.get("encoder_block_shape_q", 64),
-                kwargs.get("decoder_block_shape_q", 16),
-                kwargs.get("max_partition_size", 32768),
-                kwargs.get("encoder_max_partition_size", 32768),
                 self.config.speculate_config.speculate_max_draft_token_num,
                 True,  # causal
                 self.config.speculate_config.speculate_method is not None,  # speculate_decoder
@@ -2395,10 +2385,6 @@ def compute_attn(
                 self.quant_max_bound,
                 self.quant_min_bound,
                 self.act_scales["out_linear_in_scale"][i],
-                kwargs.get("encoder_block_shape_q", 64),
-                kwargs.get("decoder_block_shape_q", 16),
-                kwargs.get("max_partition_size", 32768),
-                kwargs.get("encoder_max_partition_size", 32768),
                 self.config.speculate_config.speculate_max_draft_token_num,
                 True,  # causal
                 self.config.speculate_config.speculate_method is not None,  # speculate_decoder
@@ -2760,10 +2746,6 @@ def compute_attn(
                 self.quant_max_bound,
                 self.quant_min_bound,
                 self.act_scales["out_linear_in_scale"][i],
-                kwargs.get("encoder_block_shape_q", 64),
-                kwargs.get("decoder_block_shape_q", 16),
-                kwargs.get("max_partition_size", 32768),
-                kwargs.get("encoder_max_partition_size", 32768),
                 self.config.speculate_config.speculate_max_draft_token_num,
                 True,  # causal
                 False,  # speculate_decoder
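
Reviewer note (not part of the patch): all four getters introduced above follow the same read-once environment-override pattern: query the FLAGS_* variable on first use, cache the parsed value in a function-local static, and fall back to the previous hard-coded default. A minimal standalone sketch of that pattern, reusing the FLAGS_cascade_attention_max_partition_size name and 128 default from this change (the main() harness is illustrative only):

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

// Same shape as get_max_partition_size() in append_attention_kernel.h:
// the flag is read and parsed once per process, then cached in a static.
inline uint32_t get_max_partition_size() {
  static const char* env = std::getenv("FLAGS_cascade_attention_max_partition_size");
  static const uint32_t value =
      env == nullptr ? 128 : std::stoul(std::string(env));
  return value;
}

int main() {
  // e.g. FLAGS_cascade_attention_max_partition_size=256 ./demo  -> prints 256
  std::cout << get_max_partition_size() << std::endl;  // prints 128 if unset
  return 0;
}

One consequence of the function-local statics worth keeping in mind during review: each flag is sampled once per process, so changing the environment variable after the first attention call has no effect.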