From 8e0ba3013cf658aa5791b41f80d392ab14f9a260 Mon Sep 17 00:00:00 2001
From: zhink <771809832@qq.com>
Date: Thu, 12 Dec 2024 11:13:28 +0800
Subject: [PATCH] add FLAGS instead of max_partition_size

---
 csrc/gpu/append_attention.cu                       | 14 --------------
 csrc/gpu/append_attn/append_attention_c16_impl.cuh |  7 ++-----
 csrc/gpu/append_attn/append_attention_c4_impl.cuh  |  7 ++-----
 csrc/gpu/append_attn/append_attention_c8_impl.cuh  |  7 ++-----
 csrc/gpu/append_attn/append_attention_kernel.h     | 14 +++++++-------
 ...ppend_attention_c16_bfloat16_bfloat16_kernel.cu |  1 -
 .../append_attention_c16_bfloat16_fp8_kernel.cu    |  1 -
 .../append_attention_c16_bfloat16_int8_kernel.cu   |  1 -
 .../append_attention_c16_float16_float16_kernel.cu |  1 -
 .../append_attention_c16_float16_fp8_kernel.cu     |  1 -
 .../append_attention_c16_float16_int8_kernel.cu    |  1 -
 ...append_attention_c4_bfloat16_bfloat16_kernel.cu |  1 -
 .../append_attention_c4_bfloat16_fp8_kernel.cu     |  1 -
 .../append_attention_c4_bfloat16_int8_kernel.cu    |  1 -
 .../append_attention_c4_float16_float16_kernel.cu  |  1 -
 .../append_attention_c4_float16_fp8_kernel.cu      |  1 -
 .../append_attention_c4_float16_int8_kernel.cu     |  1 -
 ...append_attention_c8_bfloat16_bfloat16_kernel.cu |  1 -
 .../append_attention_c8_bfloat16_fp8_kernel.cu     |  1 -
 .../append_attention_c8_bfloat16_int8_kernel.cu    |  1 -
 .../append_attention_c8_float16_float16_kernel.cu  |  1 -
 .../append_attention_c8_float16_fp8_kerne.cu       |  1 -
 .../append_attention_c8_float16_int8_kerne.cu      |  1 -
 .../transformers/fused_transformer_layers.py       |  4 ----
 24 files changed, 13 insertions(+), 58 deletions(-)

diff --git a/csrc/gpu/append_attention.cu b/csrc/gpu/append_attention.cu
index f80f8cee5d3d..e33ceede61e3 100644
--- a/csrc/gpu/append_attention.cu
+++ b/csrc/gpu/append_attention.cu
@@ -61,7 +61,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const float out_linear_in_scale,
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -209,7 +208,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -248,7 +246,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -292,7 +289,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -440,7 +436,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -479,7 +474,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -524,7 +518,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -585,7 +578,6 @@ std::vector<paddle::Tensor> AppendAttention(
     const float out_linear_in_scale,
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -650,7 +642,6 @@ std::vector<paddle::Tensor> AppendAttention(
         out_linear_in_scale,
         encoder_block_shape_q,
         decoder_block_shape_q,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -700,7 +691,6 @@ std::vector<paddle::Tensor> AppendAttention(
         out_linear_in_scale,
         encoder_block_shape_q,
         decoder_block_shape_q,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -751,7 +741,6 @@ std::vector<paddle::Tensor> AppendAttention(
         out_linear_in_scale,
         encoder_block_shape_q,
         decoder_block_shape_q,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -800,7 +789,6 @@ std::vector<paddle::Tensor> AppendAttention(
         out_linear_in_scale,
         encoder_block_shape_q,
         decoder_block_shape_q,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -905,7 +893,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const float out_linear_in_scale,
     const int encoder_block_shape_q,
     const int decoder_block_shape_q,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -985,7 +972,6 @@ PD_BUILD_OP(append_attention)
         "out_linear_in_scale: float",
         "encoder_block_shape_q: int",
         "decoder_block_shape_q: int",
-        "max_partition_size: int",
         "encoder_max_partition_size: int",
         "speculate_max_draft_token_num: int",
         "causal: bool",
diff --git a/csrc/gpu/append_attn/append_attention_c16_impl.cuh b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
index ed181836d73c..7f0771c186e4 100644
--- a/csrc/gpu/append_attn/append_attention_c16_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
@@ -786,7 +786,6 @@ void MultiQueryAppendAttention(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool is_decoder,
@@ -839,7 +838,7 @@ void MultiQueryAppendAttention(
   int sm_count;
   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-  uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+  static uint32_t chunk_size = get_max_partition_size();
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
@@ -1058,7 +1057,7 @@ void MultiQueryAppendAttention(
   int sm_count;
   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-  uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+  static uint32_t chunk_size = get_max_partition_size();
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
@@ -1301,7 +1300,6 @@ void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -1363,7 +1361,6 @@ void CascadeAppendAttentionC16Kernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         is_decoder,
diff --git a/csrc/gpu/append_attn/append_attention_c4_impl.cuh b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
index 586bde4dc741..a4c8206e6826 100644
--- a/csrc/gpu/append_attn/append_attention_c4_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c4_impl.cuh
@@ -973,7 +973,6 @@ void MultiQueryAppendC4Attention(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool is_decoder,
@@ -1036,7 +1035,7 @@ void MultiQueryAppendC4Attention(
   const float ratio = static_cast<float>(num_blocks_need) /
                       static_cast<float>(num_blocks_per_wave);
-  uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+  static uint32_t chunk_size = get_max_partition_size();
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
@@ -1282,7 +1281,7 @@ void MultiQueryAppendC4Attention(
                       static_cast<float>(num_blocks_per_wave);
-  uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+  static uint32_t chunk_size = get_max_partition_size();
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
@@ -1538,7 +1537,6 @@ void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -1604,7 +1602,6 @@ void CascadeAppendAttentionC4Kernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         is_decoder,
diff --git a/csrc/gpu/append_attn/append_attention_c8_impl.cuh b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
index d5d1cc38e1b4..2f124ccf9502 100644
--- a/csrc/gpu/append_attn/append_attention_c8_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c8_impl.cuh
@@ -860,7 +860,6 @@ void MultiQueryAppendC8Attention(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool is_decoder,
@@ -914,7 +913,7 @@ void MultiQueryAppendC8Attention(
   const int dev_id = 0;
   int sm_count;
   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-  uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+  static uint32_t chunk_size = get_max_partition_size();
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
@@ -1136,7 +1135,7 @@ void MultiQueryAppendC8Attention(
   const int dev_id = 0;
   int sm_count;
   cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
-  uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
+  static uint32_t chunk_size = get_max_partition_size();
   if (!is_decoder) {
     chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
   }
@@ -1377,7 +1376,6 @@ void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -1441,7 +1439,6 @@ void CascadeAppendAttentionC8Kernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         is_decoder,
diff --git a/csrc/gpu/append_attn/append_attention_kernel.h b/csrc/gpu/append_attn/append_attention_kernel.h
index 59532b2400c5..27e0f23debb3 100644
--- a/csrc/gpu/append_attn/append_attention_kernel.h
+++ b/csrc/gpu/append_attn/append_attention_kernel.h
@@ -52,7 +52,6 @@ void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
    const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -97,7 +96,6 @@ void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -142,7 +140,6 @@ void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -188,7 +185,6 @@ void CascadeAppendAttentionKernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
@@ -223,7 +219,6 @@ void CascadeAppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -258,7 +253,6 @@ void CascadeAppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -293,7 +287,6 @@ void CascadeAppendAttentionKernel(
         quant_max_bound,
         quant_min_bound,
         in_scale,
-        max_partition_size,
         encoder_max_partition_size,
         speculate_max_draft_token_num,
         causal,
@@ -307,3 +300,10 @@ void CascadeAppendAttentionKernel(
         "cache_int4_zp]");
   }
 }
+
+inline uint32_t get_max_partition_size() {
+  static const char* max_partition_size_env = std::getenv("FLAGS_cascade_attention_max_partition_size");
+  static const uint32_t max_partition_size =
+      max_partition_size_env == nullptr ? 128 : std::stoul(std::string(max_partition_size_env));
+  return max_partition_size;
+}
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu
index 7dafef74ba88..e5bd22e19b9b 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_bfloat16_bfloat16_kernel.cu
@@ -49,7 +49,6 @@ template void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu
index 806eecbb529d..037b188df6e9 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_float16_kernel.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu
index c677686d68aa..dafc03d3dd97 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c16_float16_fp8_kernel.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC16Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu
index 75c6e80c3056..d9fc43812499 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_bfloat16_kernel.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC4Kernel
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu
index 065834d6d0d8..b4ad573de5ce 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_bfloat16_fp8_kernel.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu
index 3a2b13a89045..77b3791526bb 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_float16_kernel.cu
@@ -49,7 +49,6 @@ template void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu
index 4f5dedb15dc5..af28876b3821 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c4_float16_fp8_kernel.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC4Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu
index 606c9128a973..a8e9b2b18b33 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_bfloat16_kernel.cu
@@ -50,7 +50,6 @@ CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu
index efc54738fafc..399b023b921d 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_bfloat16_fp8_kernel.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu
index 83728df8d409..52e473555e49 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_float16_kernel.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu
index 35267a59f55b..826890a8873c 100644
--- a/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu
+++ b/csrc/gpu/append_attn/template_instantiation/append_attention_c8_float16_fp8_kerne.cu
@@ -48,7 +48,6 @@ template void CascadeAppendAttentionC8Kernel(
     const float quant_max_bound,
     const float quant_min_bound,
     const float in_scale,
-    const int max_partition_size,
     const int encoder_max_partition_size,
     const int speculate_max_draft_token_num,
     const bool causal,
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index cdf5730c7a86..c42dc08204b7 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -1035,7 +1035,6 @@ def forward(
         if self.config.append_attn:
             kwargs["encoder_block_shape_q"] = 64
             kwargs["decoder_block_shape_q"] = 16
-            kwargs["max_partition_size"] = 32768
             kwargs["encoder_max_partition_size"] = 32768

             from paddlenlp_ops import get_block_shape_and_split_kv_block
@@ -2199,7 +2198,6 @@ def compute_attn(
             0.0,  # out_linear_in_scale
             kwargs.get("encoder_block_shape_q", 64),
             kwargs.get("decoder_block_shape_q", 16),
-            kwargs.get("max_partition_size", 32768),
             kwargs.get("encoder_max_partition_size", 32768),
             self.config.speculate_config.speculate_max_draft_token_num,
             True,  # causal
@@ -2397,7 +2395,6 @@ def compute_attn(
             self.act_scales["out_linear_in_scale"][i],
             kwargs.get("encoder_block_shape_q", 64),
             kwargs.get("decoder_block_shape_q", 16),
-            kwargs.get("max_partition_size", 32768),
             kwargs.get("encoder_max_partition_size", 32768),
             self.config.speculate_config.speculate_max_draft_token_num,
             True,  # causal
@@ -2762,7 +2759,6 @@ def compute_attn(
             self.act_scales["out_linear_in_scale"][i],
             kwargs.get("encoder_block_shape_q", 64),
             kwargs.get("decoder_block_shape_q", 16),
-            kwargs.get("max_partition_size", 32768),
             kwargs.get("encoder_max_partition_size", 32768),
             self.config.speculate_config.speculate_max_draft_token_num,
             True,  # causal
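
Usage note (not part of the diff): after this patch the decoder-side partition size is no longer passed to append_attention as a max_partition_size argument. get_max_partition_size() reads FLAGS_cascade_attention_max_partition_size from the environment and falls back to 128 when it is unset, while the encoder path keeps using encoder_max_partition_size. A minimal sketch of how a deployment might override the value, assuming the variable is set before the first attention call (the flag is cached in a function-local static, so later changes are not picked up):

    import os

    # Sketch only: choose a partition size for the cascade-attention decode path.
    # 256 is an arbitrary example value; the kernels default to 128 when the flag is unset.
    os.environ["FLAGS_cascade_attention_max_partition_size"] = "256"

    # ... build and run the model as usual; the fused transformer layers no longer
    # pass max_partition_size through kwargs after this change.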