From 6fec1cd0dfa53b4f5b228d1ff328ca6f99e0d8aa Mon Sep 17 00:00:00 2001 From: Umar Arshad Date: Tue, 10 Dec 2024 18:34:09 -0800 Subject: [PATCH 1/2] xe: sdpa: Allow non-transposed scalars for the KQ matmul --- src/gpu/intel/ocl/micro_sdpa.cl | 2 +- src/gpu/intel/ocl/micro_sdpa.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gpu/intel/ocl/micro_sdpa.cl b/src/gpu/intel/ocl/micro_sdpa.cl index bb4fb6c43cb..34d52331d43 100644 --- a/src/gpu/intel/ocl/micro_sdpa.cl +++ b/src/gpu/intel/ocl/micro_sdpa.cl @@ -181,7 +181,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q, uint lda = DST_S2; #if KEY_SCALES || KEY_ZERO_POINTS - uint ldkq = div_up(d, KEY_GROUP_SIZE); + uint ldkq = KEY_D3; #endif #if VAL_SCALES || VAL_ZERO_POINTS uint ldvq = div_up(d, VAL_GROUP_SIZE); diff --git a/src/gpu/intel/ocl/micro_sdpa.cpp b/src/gpu/intel/ocl/micro_sdpa.cpp index d37bfd5cd98..0555ee2d1cc 100644 --- a/src/gpu/intel/ocl/micro_sdpa.cpp +++ b/src/gpu/intel/ocl/micro_sdpa.cpp @@ -242,14 +242,14 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) { auto scale_dt = key_scales_dt(); problem_kq.Ta_scale = jit::convert_dnnl_to_kernel_type(scale_dt); problem_kq.A_scale.alignment = uint8_t(types::data_type_size(scale_dt)); - problem_kq.A_scale.layout = MatrixLayout::T; + problem_kq.A_scale.layout = MatrixLayout::N; problem_kq.aScale2D = true; } if (with_key_zp()) { auto zp_dt = key_zp_dt(); problem_kq.Tao = jit::convert_dnnl_to_kernel_type(zp_dt); problem_kq.AO.alignment = uint8_t(types::data_type_size(zp_dt)); - problem_kq.AO.layout = MatrixLayout::T; + problem_kq.AO.layout = MatrixLayout::N; problem_kq.aoPtrDims = kq_common_zp ? 0 : 2; problem_kq.aOffset = ABOffset::Calc; } From e4f55336fd42f16536288cb5878a0bf7faea9565 Mon Sep 17 00:00:00 2001 From: Umar Arshad Date: Tue, 17 Dec 2024 17:38:32 -0800 Subject: [PATCH 2/2] xe: sdpa: Fix KQ gemm alignment for the K tensor --- src/gpu/intel/ocl/micro_sdpa.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gpu/intel/ocl/micro_sdpa.cpp b/src/gpu/intel/ocl/micro_sdpa.cpp index 0555ee2d1cc..e1076623805 100644 --- a/src/gpu/intel/ocl/micro_sdpa.cpp +++ b/src/gpu/intel/ocl/micro_sdpa.cpp @@ -264,7 +264,10 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) { problem_kq.B.layout = MatrixLayout::Pr; problem_kq.C.layout = MatrixLayout::T; - problem_kq.A.setAlignment(alignmentForLD(d->head_size() * problem.Ta)); + const memory_desc_wrapper key_mdw(key_md()); + auto ldk = static_cast( + gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size()); + problem_kq.A.setAlignment(alignmentForLD(ldk)); problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM problem_kq.B.crosspack = 2; problem_kq.B.tileR = into(d_max());