diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 51fecd07b5..07e07d464d 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -355,6 +355,9 @@ def api(self) -> str: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) @dataclass @@ -489,7 +492,8 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: pipelines = [] if dtype in ['fp16', 'bf16']: for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): - if hdim == 256: + # if hdim=32, fallback to 'qr' pipeline to workaround rocm 6.2 compiler problem (missing s_waitcnt) + if hdim == 256 or hdim == 32: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) @@ -497,11 +501,18 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) - if receipt == 1: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index b2f1f790ae..5b1aa08ebb 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -34,234 +34,338 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz return 
r; } +namespace impl { +// below type indicate the data type used for buffer load inline asm +// clang-format off +template struct buffer_load_trait; + +template struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; }; +template struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; }; +template struct buffer_load_trait<4 , T> { using payload_t = float; }; +template struct buffer_load_trait<2 , T> { using payload_t = float; }; +template struct buffer_load_trait<1 , T> { using payload_t = float; }; + +#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA +template<> struct buffer_load_trait<16, thread_buffer> { using payload_t = bf16x8_t; }; +template<> struct buffer_load_trait<8 , thread_buffer> { using payload_t = bf16x4_t; }; +template<> struct buffer_load_trait<4 , thread_buffer> { using payload_t = bf16x2_t; }; +#endif +// clang-format on +} // namespace impl + // TODO: glc/slc/... -template +template struct buffer_load; #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // (exp_vector_type(xxx)) -template <> -struct buffer_load<16> +template +struct buffer_load<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; - asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<8> +template +struct buffer_load<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; - asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<4> +template +struct buffer_load<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm 
volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<2> +template +struct buffer_load<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = float; - asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template <> -struct buffer_load<1> +template +struct buffer_load<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0) + index_t /*flag*/ = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); + else + asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; -template +template struct buffer_load_if; -template <> -struct buffer_load_if<16> +template +struct buffer_load_if<16, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + 
"v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<8> +template +struct buffer_load_if<8, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x2_t; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<4> +template +struct buffer_load_if<4, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<2> +template +struct buffer_load_if<2, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile( - "v_cmpx_le_u32 exec, 1, 
%5\n" - "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; -template <> -struct buffer_load_if<1> +template +struct buffer_load_if<1, pre_nop> { template CK_TILE_DEVICE void operator()(T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile( - "v_cmpx_le_u32 exec, 1, %5\n" - "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + else + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); } }; #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" @@ -275,17 +379,16 @@ struct buffer_store<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 16); using mbuf_t = fp32x4_t; - asm volatile( - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -296,17 +399,16 @@ struct buffer_store<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 8); using mbuf_t = fp32x2_t; - asm volatile( - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -317,17 +419,16 @@ struct buffer_store<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer 
resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_dword %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -338,17 +439,16 @@ struct buffer_store<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 2); using mbuf_t = short; - asm volatile( - "buffer_store_short %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -359,17 +459,16 @@ struct buffer_store<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t /*flag*/ = 1) { static_assert(sizeof(T) == 4); using mbuf_t = float; - asm volatile( - "buffer_store_byte %0, %1, %2, %3 offen offset:%4" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) - : "memory"); + asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) + : "memory"); } }; @@ -383,21 +482,20 @@ struct buffer_store_if<16> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = fp32x4_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -412,7 +510,7 @@ struct buffer_store_if<8> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { @@ -420,14 +518,13 @@ struct buffer_store_if<8> auto save_exec = __builtin_amdgcn_read_exec(); // TODO: ugly. 
rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch using mbuf_t = ext_vector_t; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -442,21 +539,20 @@ struct buffer_store_if<4> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -471,21 +567,20 @@ struct buffer_store_if<2> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = short; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -500,21 +595,20 @@ struct buffer_store_if<1> CK_TILE_DEVICE void operator()(const T& value, int32x4_t res /*buffer resource*/, index_t v_offset, - index_t s_offset, + index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, index_t flag = 1) { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %5\n" - "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" - "s_mov_b64 exec %6" + asm volatile("v_cmpx_le_u32 exec, 1, %4\n" + "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" + "s_mov_b64 exec %5" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), - "s"(s_offset), "n"(i_offset), "v"(flag), "s"(save_exec) @@ -538,8 +632,9 @@ namespace impl{ template CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) { - static_for<0, b.size(), 1>{}([&](auto i){ - asm volatile(" " : : "v"(b.get(i)) : "memory"); + constexpr auto kSize = remove_cvref_t::size(); + static_for<0, kSize, 1>{}([&](auto i){ + asm volatile(" " : : "v"(b.get(number{})) : "memory"); }); } #if 1 @@ -769,6 +864,28 @@ llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); +// buffer store ui16 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x2(uint16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_ui16x4(uint16x4_t 
vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, int32x4_t rsrc, @@ -859,17 +976,26 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, int soffset, // dst_wave_addr_offset int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); -CK_TILE_DEVICE void async_buffer_load_dword(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0) +template +CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t /*soffset*/, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0, + bool_constant = {}) { - asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) - : "memory"); + if constexpr(pre_nop) + asm volatile("s_nop 4\n" + "buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); + else + asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "n"(ioffset) + : "memory"); } CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) @@ -1181,12 +1307,14 @@ CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffe template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t flag = 0) + index_t flag = 0, + bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, @@ -1195,32 +1323,46 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, using type = thread_buffer; if constexpr(oob_conditional_check) { - buffer_load_if{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load_if{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } else { - buffer_load{}( - dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + buffer_load{}(dst, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + 0, + flag, + bool_constant{}); } } template + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0) + index_t src_immediate_addr_offset = 0, + bool_constant = {}) { static_assert(sizeof(T) * N == 4, "wrong! 
not implemented vector size"); - async_buffer_load_dword(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset); + async_buffer_load_dword_v(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset, + 0, + bool_constant{}); } template src_thread_d (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); if constexpr(std::is_same::value) // fp32 @@ -1478,6 +1623,49 @@ CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_d static_cast(coherence)); } } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_ui16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_ui16x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(uint16_t), + static_cast(coherence)); + } + } else { using r_t = thread_buffer; @@ -1595,7 +1783,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_th { if constexpr(N == 2) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, @@ -1821,20 +2009,50 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, template + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, const T* p_src_wave, index_t src_thread_element_offset, index_t src_element_space_size, - index_t is_valid_element = 0) + index_t is_valid_element = 0, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - amd_buffer_load_raw_impl( - dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, + const int32x4_t 
src_wave_buffer_resource, + index_t src_thread_element_offset, + index_t is_valid_element = 0, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_buffer_load_raw_impl( + dst, + src_wave_buffer_resource, + src_thread_addr_offset, + 0, + is_valid_element, + bool_constant{}); } // unfortunately async copy can not make sure invalid data is zero inside LDS @@ -1843,11 +2061,13 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, // buffer_load OOB still working. template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_element_space_size) + amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default, + bool pre_nop = false> +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size, + bool_constant = {}) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); @@ -1855,7 +2075,23 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); amd_async_buffer_load_impl( - smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); +} + +// This version support buffer resource as input arg +template +CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, + const int32x4_t src_wave_buffer_resource, + index_t src_thread_element_offset, + bool_constant = {}) +{ + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0, bool_constant{}); } // buffer_store requires: diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 888f0e728f..e3291b8336 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -79,14 +79,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load() " ::); } -CK_TILE_DEVICE void s_nop() +CK_TILE_DEVICE void s_nop(index_t cnt = 0) { #if 1 - asm volatile("\ - s_nop 0 \n \ - " ::); + asm volatile("s_nop %0" : : "n"(cnt) :); #else - __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_sched_barrier(cnt); #endif } diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 601aad19bd..93fe135012 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -18,6 +18,7 @@ #define __gfx11__ #endif +#include "hip/hip_version.h" #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -144,6 +145,15 @@ #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 #endif +#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE +#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091) || \ + (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1 +#else +#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0 +#endif +#endif + #ifndef CK_TILE_DEBUG_LOG #define CK_TILE_DEBUG_LOG 0 #endif @@ -167,7 +177,15 @@ #define CK_TILE_USE_SUBDWORD_TILE_CAST 0 #endif +#ifndef CK_TILE_USE_PK_FP16_TILE_CAST +#define CK_TILE_USE_PK_FP16_TILE_CAST 0 +#endif + // TODO: better solve this inside compiler #ifndef CK_TILE_FMHA_FWD_FAST_EXP2 
#define CK_TILE_FMHA_FWD_FAST_EXP2 0 #endif + +#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA +#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1 +#endif diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index 96b38241c0..ddf6f97afa 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -68,6 +68,8 @@ struct buffer_view invalid_element_value_ = T{0}; CK_TILE_HOST_DEVICE constexpr buffer_view() - : p_data_{}, buffer_size_{}, invalid_element_value_{} + : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0} { } CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size, T invalid_element_value) - : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + : p_data_{p_data}, + buffer_size_{buffer_size}, + cached_buf_res_{0}, + invalid_element_value_{invalid_element_value} { } + // this is non constexpr intentially (will call some intrinsic internally) + // Must call for buffers that need *_raw load/store + CK_TILE_HOST_DEVICE void init_raw() + { + cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type)); + } + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() { return address_space_enum::global; @@ -332,12 +345,15 @@ struct buffer_view>::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const + CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t& dst, + index_t i, + bool is_valid_element, + bool_constant = {}) const { constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -348,18 +364,21 @@ struct buffer_view, t_per_x, Coherence, oob_conditional_check>( - dst, p_data_, i, buffer_size_, is_valid_element); + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check, pre_nop>( + dst, cached_buf_res_, i, is_valid_element, bool_constant{}); } // i is offset of T, not X. i should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>::value, bool>::type = false> - CK_TILE_DEVICE constexpr auto - async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t* smem, + index_t i, + bool /*is_valid_element*/, + bool_constant = {}) const { // X is vector of T constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; @@ -370,8 +389,8 @@ struct buffer_view, t_per_x, Coherence>( - smem, p_data_, i, buffer_size_); + amd_async_buffer_load_with_oob_raw, t_per_x, Coherence>( + smem, cached_buf_res_, i, bool_constant{}); } // i is offset of T, not X. 
i should be aligned to X @@ -626,6 +645,8 @@ struct buffer_view + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto load_tile_raw(T& tile, const tile_window_with_static_distribution& tile_window, - bool_constant = {}) + bool_constant = {}, + bool_constant = {}) { - tile_window.load_raw(tile, bool_constant{}); + tile_window.load_raw(tile, bool_constant{}, bool_constant{}); } template + index_t NumCoord, + bool oob_conditional_check = true, + bool pre_nop = false> CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, const tile_window_with_static_distribution& tile_window) + NumCoord>& tile_window, + bool_constant = {}, + bool_constant = {}) { - return tile_window.async_load(lds_tile); + return tile_window.async_load_raw( + lds_tile, bool_constant{}, bool_constant{}); } CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp index 89806203ab..9707f2990a 100644 --- a/include/ck_tile/core/tensor/null_tile_window.hpp +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -35,6 +35,8 @@ struct null_tile_window CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } + CK_TILE_DEVICE void init_raw() {} + WindowLengths window_lengths_; }; diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index e37bd806de..6d40916893 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -33,6 +33,8 @@ struct tensor_view { } + CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); } + CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() @@ -82,30 +84,34 @@ struct tensor_view // "coord" is coordinate of DataType, not X. "coord" should be aligned to X template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE void - get_vectorized_elements_raw(remove_cvref_t& dst, - const TensorCoord& coord, - bool_constant = {}) const + CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}, + bool_constant = {}) const { - return buf_.template get_raw( + return buf_.template get_raw( dst, coord.get_offset(), - coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); } template >::scalar_type, typename vector_traits>::scalar_type>, bool>::type = false> - CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, - const TensorCoord& coord) const + CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw( + remove_cvref_t* smem, const TensorCoord& coord, bool_constant = {}) const { - return buf_.template async_get(smem, coord.get_offset(), true /*not used*/); + return buf_.template async_get_raw( + smem, coord.get_offset(), true /*not used*/, bool_constant{}); } // X is vector of DataType. 
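Note on the new raw-access plumbing: buffer_view now carries a cached wave buffer resource (cached_buf_res_, filled once by init_raw() via make_wave_buffer_resource), and tensor_view / null_tile_window (and, further below, tile_window) forward init_raw() down to it, so the *_raw load paths reuse one pre-built 128-bit buffer descriptor instead of rebuilding it on every access. A condensed usage sketch, assuming device code with a DRAM tile window `w` created by make_tile_window; the window name and the flag values are illustrative, not part of this patch:

    // C++ (HIP device code), using the API extended above
    w.init_raw();                               // build and cache the buffer resource once
    auto t = decltype(load_tile(w)){};          // distributed tensor filled by the raw load
    load_tile_raw(t, w,
                  bool_constant<true>{},        // oob_conditional_check
                  bool_constant<false>{});      // pre_nop

With oob_conditional_check set, the per-access valid flag is applied through the v_cmpx_le_u32/exec-mask variant (buffer_load_if); otherwise the plain buffer_load variant is emitted. pre_nop, when set, prepends "s_nop 4" to the very first access of the window only (iCoord == 0, iCoordAccess == 0 in load_raw).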
diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 48762b7225..79018b9ced 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&) // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with // sub-dword tensor... -template -CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) +template +CK_TILE_DEVICE void +set_tile(DstrTensors& dstr_tensor, number, bool_constant = {}) { - constexpr index_t tensor_bytes = - DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); - if constexpr(v == 0 && tensor_bytes % 4 == 0) + using elem_type = typename DstrTensors::DataType; + constexpr index_t elem_size = sizeof(elem_type); + + constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size; + + // # bytes per write = 4 + if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt) { +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + auto& buffer = dstr_tensor.get_thread_buffer(); + + static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) { + if constexpr(elem_size == 1) + { + // # elements per write = 4 + constexpr auto values = ext_vector_t{0, 0, 0, 0}; + + buffer[i_write * 4 + 0] = values.x; + buffer[i_write * 4 + 1] = values.y; + buffer[i_write * 4 + 2] = values.z; + buffer[i_write * 4 + 3] = values.w; + } + else if constexpr(elem_size == 2) + { + // # elements per write = 2 + constexpr auto values = ext_vector_t{0, 0}; + + buffer[i_write * 2 + 0] = values.x; + buffer[i_write * 2 + 1] = values.y; + } + else if constexpr(elem_size == 4) + { + // # elements per write = 1 + constexpr elem_type value = 0; + + buffer[i_write] = value; + } + else + { + static_assert(false, "type not supported"); + } + }); +#else using dvec_t = array; auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); for(auto i = 0; i < tensor.size(); i++) tensor.get(i) = v; +#endif } else { - tile_elementwise_inout( - [](auto& x) { x = type_convert(v); }, - dstr_tensor); + tile_elementwise_inout([](auto& x) { x = type_convert(v); }, + dstr_tensor); } } @@ -110,7 +150,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) namespace impl { // TODO: this is ugly template -CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) +CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) { #if defined(__gfx94__) // This API is designed to use the _pk_ serious of function @@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) #endif } +template +CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) +{ +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) + // This API is designed to use the _pk_ serious of function + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + static_assert(thread_buffer_size % 2 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + // TODO: this is rtz cvt, need be very careful + for(index_t i = 0; i < thread_buffer_size_pk; i++) + { + auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0], + in_dstr_tensors.get_thread_buffer()[2 * i + 1]); + + 
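+            // o is a packed fp16x2: .x holds converted element 2*i and .y holds element 2*i+1
+            // (__builtin_amdgcn_cvt_pkrtz converts with round-toward-zero, hence the TODO above).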
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x; + out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y; + } + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + #if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) @@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) float> && (SrcTensor::get_thread_buffer_size() % 4 == 0)) { - return impl::cast_tile_pk_fp8x4(src_tensor); + return impl::cast_tile_pk_fp8_fp32(src_tensor); } +#if CK_TILE_USE_PK_FP16_TILE_CAST + else if constexpr(std::is_same_v && + std::is_same_v && + (SrcTensor::get_thread_buffer_size() % 2 == 0)) + { + return impl::cast_tile_pk_fp16_fp32(src_tensor); + } +#endif #if CK_TILE_USE_SUBDWORD_TILE_CAST else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) { diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 09a4eb1fc0..a080e12b94 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -344,9 +344,10 @@ struct tile_window_with_static_distribution return dst_tensor; } - template + template CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, - bool_constant = {}) const + bool_constant = {}, + bool_constant = {}) const { using Traits = load_store_traits; @@ -373,7 +374,13 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // data index [y0, y1, ...] 
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); @@ -384,8 +391,12 @@ struct tile_window_with_static_distribution get_bottom_tensor_view().template get_vectorized_elements_raw( dst_vec_tbuf.template at(), bottom_tensor_thread_coord, - bool_constant{}); - + bool_constant{}, + pre_nop_); +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + asm volatile( + ""); // this is starting from rocm-6.2, but same sympton, reuse this flag +#endif // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { @@ -399,12 +410,17 @@ struct tile_window_with_static_distribution } }); }); +#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE + asm volatile("; this inline asm is workaround to prevent compiler from using too much " + "scratch memory" ::); +#endif } // TODO: currently async load only implemented in inline asm - template - CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, - bool_constant = {}) const + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + bool_constant = {}, + bool_constant = {}) const { using LdsTileWindow = remove_cvref_t; // using LdsTensorView = typename LdsTileWindow::BottomTensorView; @@ -449,11 +465,17 @@ struct tile_window_with_static_distribution auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { - constexpr auto iAccess = number{}; + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); // read from bottom tensor - get_bottom_tensor_view().template async_get_vectorized_elements( - smem, bottom_tensor_thread_coord); + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, pre_nop_); // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) @@ -608,6 +630,67 @@ struct tile_window_with_static_distribution }); } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_dstr_), array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; 
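+        // SFC_Ys is the space-filling curve over the tile's Y dimensions; get_step_between()
+        // below yields the accumulated Y step from access 0 to the first access of each of the
+        // NumCoord bundles, which is used to advance the freshly built thread coordinates.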
+ using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + // this is the bottom tensor view // [x0', x1', ...] ==> [offset] BottomTensorView bottom_tensor_view_; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 8a19deb02a..d99d53c7d0 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -78,6 +78,12 @@ struct BlockFmhaPipelineQRKSVSAsync return Problem::kBlockPerCu; else { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS) + { + return 1; + } + if constexpr(kK0BlockLength <= 32) { if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && @@ -212,11 +218,14 @@ struct BlockFmhaPipelineQRKSVSAsync q_dram_block_window_tmp.get_window_lengths(), q_dram_block_window_tmp.get_window_origin(), Policy::template MakeQDramTileDistribution()); + q_dram_window.init_raw(); // TODO: we use async Copy for K, which is inline asm // a side effect is we have to use inline asm for q as well auto q = decltype(load_tile(q_dram_window)){}; - set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. 
+ // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch load_tile_raw(q, q_dram_window); __builtin_amdgcn_sched_barrier(0); @@ -285,6 +294,16 @@ struct BlockFmhaPipelineQRKSVSAsync k_dram_block_window.get_window_origin(), Policy::template MakeKDramTileDistribution()); // K DRAM tile window for // load + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS))) + return bool_constant{}; + else + return bool_constant{}; + }(); + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = make_tile_window( bias_dram_block_window_tmp.get_bottom_tensor_view(), @@ -299,7 +318,7 @@ struct BlockFmhaPipelineQRKSVSAsync Policy::template MakeVDramTileDistribution()); // prefetch K tile - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); __builtin_amdgcn_sched_barrier(0); @@ -322,7 +341,9 @@ struct BlockFmhaPipelineQRKSVSAsync { static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { async_load_tile_raw(k_lds_store(number{})>{}), - k_dram_window); + k_dram_window, + k_oob_ck, + k_pre_np); if constexpr(i_k0 < k0_loops - 1) move_tile_window(k_dram_window, {0, kK0}); @@ -609,16 +630,13 @@ struct BlockFmhaPipelineQRKSVSAsync { // move K tile windows move_tile_window(k_dram_block_window, {kN0, 0}); - k_dram_window = - make_tile_window(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - Policy::template MakeKDramTileDistribution()); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); if constexpr(k1_loops >= 2 && LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) __builtin_amdgcn_s_barrier(); - async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window); + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np); move_tile_window(k_dram_window, {0, kK0}); } // tail
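Note on the reworked async K-load path above: the K DRAM window is now built once, init_raw()'d, and re-aimed with set_window_origin() instead of being reconstructed every outer iteration, and async_load_tile_raw() takes two compile-time knobs, oob_conditional_check and pre_nop (the pipeline enables pre_nop only when seqlen-K padding is combined with a bias). A condensed sketch of the call pattern, with identifiers taken from the pipeline code above and the flag values shown purely for illustration (LDS sequencing and loop structure omitted):

    // C++ (HIP device code); k_dram_window / k_lds_store / LdsSeq as in the pipeline above
    k_dram_window.init_raw();                          // cache the buffer descriptor once
    constexpr auto k_oob_ck = bool_constant<true>{};   // oob_conditional_check knob
    constexpr auto k_pre_np = bool_constant<false>{};  // pre_nop: "s_nop 4" before the first load when true
    async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
    move_tile_window(k_dram_window, {0, kK0});
    async_load_fence();                                // wait on the outstanding "buffer_load ... lds"
    // next outer iteration: re-aim the same window instead of remaking it
    k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());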