Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added FLAGS to replace four params and the value can be adjusted for better speedup #9624

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 2 additions & 45 deletions csrc/gpu/append_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
Expand All @@ -76,7 +72,8 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
int max_enc_len_this_time_data = max_enc_len_this_time.data<int>()[0];
int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
int max_len_kv_data = max_len_kv.data<int>()[0];

const int encoder_block_shape_q = get_encoder_block_shape_q();
const int decoder_block_shape_q = get_decoder_block_shape_q();
auto main_stream = qkv.stream();
static cudaEvent_t main_event;
static cudaEvent_t decoder_event;
Expand Down Expand Up @@ -209,8 +206,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
false,
Expand Down Expand Up @@ -248,8 +243,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
false,
Expand Down Expand Up @@ -292,8 +285,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
false,
Expand Down Expand Up @@ -440,8 +431,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
!speculate_decoder,
Expand Down Expand Up @@ -479,8 +468,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
!speculate_decoder,
Expand Down Expand Up @@ -524,8 +511,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
!speculate_decoder,
Expand Down Expand Up @@ -583,10 +568,6 @@ std::vector<paddle::Tensor> AppendAttention(
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
Expand Down Expand Up @@ -648,10 +629,6 @@ std::vector<paddle::Tensor> AppendAttention(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
Expand Down Expand Up @@ -698,10 +675,6 @@ std::vector<paddle::Tensor> AppendAttention(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
Expand Down Expand Up @@ -749,10 +722,6 @@ std::vector<paddle::Tensor> AppendAttention(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
Expand Down Expand Up @@ -798,10 +767,6 @@ std::vector<paddle::Tensor> AppendAttention(
quant_max_bound,
quant_min_bound,
out_linear_in_scale,
encoder_block_shape_q,
decoder_block_shape_q,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
speculate_decoder);
Expand Down Expand Up @@ -903,10 +868,6 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
const float quant_max_bound,
const float quant_min_bound,
const float out_linear_in_scale,
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool speculate_decoder) {
Expand Down Expand Up @@ -983,10 +944,6 @@ PD_BUILD_OP(append_attention)
"quant_max_bound: float",
"quant_min_bound: float",
"out_linear_in_scale: float",
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"max_partition_size: int",
"encoder_max_partition_size: int",
"speculate_max_draft_token_num: int",
"causal: bool",
"speculate_decoder: bool"})
Expand Down
14 changes: 4 additions & 10 deletions csrc/gpu/append_attn/append_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -786,8 +786,6 @@ void MultiQueryAppendAttention(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
Expand Down Expand Up @@ -839,9 +837,9 @@ void MultiQueryAppendAttention(
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);

uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
uint32_t chunk_size = get_max_partition_size();
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
chunk_size = get_encoder_max_partition_size();
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
Expand Down Expand Up @@ -1058,9 +1056,9 @@ void MultiQueryAppendAttention(
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);

uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
uint32_t chunk_size = get_max_partition_size();
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
chunk_size = get_encoder_max_partition_size();
}
const int num_chunks = div_up(max_dec_len, chunk_size);

Expand Down Expand Up @@ -1301,8 +1299,6 @@ void CascadeAppendAttentionC16Kernel(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
Expand Down Expand Up @@ -1363,8 +1359,6 @@ void CascadeAppendAttentionC16Kernel(
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
Expand Down
14 changes: 4 additions & 10 deletions csrc/gpu/append_attn/append_attention_c4_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -973,8 +973,6 @@ void MultiQueryAppendC4Attention(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
Expand Down Expand Up @@ -1036,9 +1034,9 @@ void MultiQueryAppendC4Attention(
const float ratio = static_cast<float>(num_blocks_need) /
static_cast<float>(num_blocks_per_wave);

uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
uint32_t chunk_size = get_max_partition_size();
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
chunk_size = get_encoder_max_partition_size();
}
const int num_chunks = div_up(max_dec_len, chunk_size);

Expand Down Expand Up @@ -1282,9 +1280,9 @@ void MultiQueryAppendC4Attention(
static_cast<float>(num_blocks_per_wave);


uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
uint32_t chunk_size = get_max_partition_size();
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
chunk_size = get_encoder_max_partition_size();
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
Expand Down Expand Up @@ -1538,8 +1536,6 @@ void CascadeAppendAttentionC4Kernel(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
Expand Down Expand Up @@ -1604,8 +1600,6 @@ void CascadeAppendAttentionC4Kernel(
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
Expand Down
14 changes: 4 additions & 10 deletions csrc/gpu/append_attn/append_attention_c8_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -860,8 +860,6 @@ void MultiQueryAppendC8Attention(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool is_decoder,
cudaStream_t &stream,
Expand Down Expand Up @@ -914,9 +912,9 @@ void MultiQueryAppendC8Attention(
const int dev_id = 0;
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
uint32_t chunk_size = get_max_partition_size();
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
chunk_size = get_encoder_max_partition_size();
}
const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
Expand Down Expand Up @@ -1136,9 +1134,9 @@ void MultiQueryAppendC8Attention(
const int dev_id = 0;
int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
uint32_t chunk_size = get_max_partition_size();
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
chunk_size = get_encoder_max_partition_size();
}

const int num_chunks = div_up(max_dec_len, chunk_size);
Expand Down Expand Up @@ -1377,8 +1375,6 @@ void CascadeAppendAttentionC8Kernel(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
Expand Down Expand Up @@ -1441,8 +1437,6 @@ void CascadeAppendAttentionC8Kernel(
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
is_decoder,
stream,
Expand Down
28 changes: 14 additions & 14 deletions csrc/gpu/append_attn/append_attention_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ void CascadeAppendAttentionC16Kernel(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
Expand Down Expand Up @@ -97,8 +95,6 @@ void CascadeAppendAttentionC8Kernel(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
Expand Down Expand Up @@ -142,8 +138,6 @@ void CascadeAppendAttentionC4Kernel(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
Expand Down Expand Up @@ -188,8 +182,6 @@ void CascadeAppendAttentionKernel(
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
const int max_partition_size,
const int encoder_max_partition_size,
const int speculate_max_draft_token_num,
const bool causal,
const bool is_decoder,
Expand Down Expand Up @@ -223,8 +215,6 @@ void CascadeAppendAttentionKernel(
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
Expand Down Expand Up @@ -258,8 +248,6 @@ void CascadeAppendAttentionKernel(
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
Expand Down Expand Up @@ -293,8 +281,6 @@ void CascadeAppendAttentionKernel(
quant_max_bound,
quant_min_bound,
in_scale,
max_partition_size,
encoder_max_partition_size,
speculate_max_draft_token_num,
causal,
is_decoder,
Expand All @@ -307,3 +293,17 @@ void CascadeAppendAttentionKernel(
"cache_int4_zp]");
}
}

// Chunk (partition) size used for the split-KV decoder attention pass.
// Overridable at runtime via the FLAGS_cascade_attention_max_partition_size
// environment variable; defaults to 128. Parsed once on first call
// (thread-safe function-local static) and cached for all later calls.
// Malformed, empty, or zero values fall back to the default instead of
// throwing: a zero chunk size would divide by zero in div_up(), and
// std::stoul on a garbage flag would abort the kernel launch path.
inline uint32_t get_max_partition_size() {
  static const uint32_t max_partition_size = []() -> uint32_t {
    constexpr uint32_t kDefault = 128;
    const char* env = std::getenv("FLAGS_cascade_attention_max_partition_size");
    if (env == nullptr || *env == '\0') return kDefault;
    char* end = nullptr;
    const unsigned long parsed = std::strtoul(env, &end, 10);
    if (end == env || parsed == 0) return kDefault;  // no digits or zero
    return static_cast<uint32_t>(parsed);
  }();
  return max_partition_size;
}

// Chunk (partition) size used for the split-KV encoder attention pass.
// Overridable at runtime via the
// FLAGS_cascade_encoder_attention_max_partition_size environment variable;
// defaults to 32768. Parsed once on first call (thread-safe function-local
// static) and cached for all later calls. Malformed, empty, or zero values
// fall back to the default instead of throwing: a zero chunk size would
// divide by zero in div_up(), and std::stoul on a garbage flag would abort
// the kernel launch path.
inline uint32_t get_encoder_max_partition_size() {
  static const uint32_t encoder_max_partition_size = []() -> uint32_t {
    constexpr uint32_t kDefault = 32768;
    const char* env =
        std::getenv("FLAGS_cascade_encoder_attention_max_partition_size");
    if (env == nullptr || *env == '\0') return kDefault;
    char* end = nullptr;
    const unsigned long parsed = std::strtoul(env, &end, 10);
    if (end == env || parsed == 0) return kDefault;  // no digits or zero
    return static_cast<uint32_t>(parsed);
  }();
  return encoder_max_partition_size;
}
Loading