Skip to content

Commit

Permalink
[CK_TILE] Pick bugfixes for ROCm 6.2 compiler issues (#1430)
Browse files Browse the repository at this point in the history
  • Loading branch information
poyenc committed Aug 5, 2024
1 parent 00626ca commit 6659340
Show file tree
Hide file tree
Showing 11 changed files with 747 additions and 266 deletions.
23 changes: 17 additions & 6 deletions example/ck_tile/01_fmha/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def api(self) -> str:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)

@dataclass
Expand Down Expand Up @@ -489,19 +492,27 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
if hdim == 256:
# if hdim=32, fallback to 'qr' pipeline to workaround rocm 6.2 compiler problem (missing s_waitcnt)
if hdim == 256 or hdim == 32:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))

pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
if receipt == 1:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
Expand Down
634 changes: 435 additions & 199 deletions include/ck_tile/core/arch/amd_buffer_addressing.hpp

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
" ::);
}

CK_TILE_DEVICE void s_nop()
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1
asm volatile("\
s_nop 0 \n \
" ::);
asm volatile("s_nop %0" : : "n"(cnt) :);
#else
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_sched_barrier(cnt);
#endif
}

Expand Down
18 changes: 18 additions & 0 deletions include/ck_tile/core/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#define __gfx11__
#endif

#include "hip/hip_version.h"
#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
Expand Down Expand Up @@ -144,6 +145,15 @@
#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
#endif

#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133)
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
#endif
#endif

#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
Expand All @@ -167,7 +177,15 @@
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
#endif

#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
#define CK_TILE_USE_PK_FP16_TILE_CAST 0
#endif

// TODO: better solve this inside compiler
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
#endif

#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
45 changes: 34 additions & 11 deletions include/ck_tile/core/tensor/buffer_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ struct buffer_view<address_space_enum::generic,
{
}

CK_TILE_HOST_DEVICE void init_raw() {}

CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::generic;
Expand Down Expand Up @@ -223,25 +225,36 @@ struct buffer_view<address_space_enum::global,

T* p_data_ = nullptr;
BufferSizeType buffer_size_;
int32x4_t cached_buf_res_;
remove_cvref_t<T> invalid_element_value_ = T{0};

CK_TILE_HOST_DEVICE constexpr buffer_view()
: p_data_{}, buffer_size_{}, invalid_element_value_{}
: p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
{
}

CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
: p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
{
}

CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
: p_data_{p_data},
buffer_size_{buffer_size},
cached_buf_res_{0},
invalid_element_value_{invalid_element_value}
{
}

// this is non constexpr intentially (will call some intrinsic internally)
// Must call for buffers that need *_raw load/store
CK_TILE_HOST_DEVICE void init_raw()
{
cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
}

CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::global;
Expand Down Expand Up @@ -332,12 +345,15 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
CK_TILE_DEVICE constexpr auto get_raw(remove_cvref_t<X>& dst,
index_t i,
bool is_valid_element,
bool_constant<pre_nop> = {}) const
{
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

Expand All @@ -348,18 +364,21 @@ struct buffer_view<address_space_enum::global,

constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
dst, p_data_, i, buffer_size_, is_valid_element);
amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
}

// i is offset of T, not X. i should be aligned to X
template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE constexpr auto
async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
CK_TILE_DEVICE constexpr auto async_get_raw(remove_cvref_t<T>* smem,
index_t i,
bool /*is_valid_element*/,
bool_constant<pre_nop> = {}) const
{
// X is vector of T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
Expand All @@ -370,8 +389,8 @@ struct buffer_view<address_space_enum::global,

constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
smem, p_data_, i, buffer_size_);
amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
smem, cached_buf_res_, i, bool_constant<pre_nop>{});
}

// i is offset of T, not X. i should be aligned to X
Expand Down Expand Up @@ -626,6 +645,8 @@ struct buffer_view<address_space_enum::lds,
{
}

CK_TILE_HOST_DEVICE void init_raw() {}

CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::lds;
Expand Down Expand Up @@ -908,6 +929,8 @@ struct buffer_view<address_space_enum::vgpr,
{
}

CK_TILE_HOST_DEVICE void init_raw() {}

CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
{
return address_space_enum::vgpr;
Expand Down
19 changes: 13 additions & 6 deletions include/ck_tile/core/tensor/load_tile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,37 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {})
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

template <typename LdsTileWindow_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord>
index_t NumCoord,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window)
NumCoord>& tile_window,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load(lds_tile);
return tile_window.async_load_raw(
lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}

CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
Expand Down
2 changes: 2 additions & 0 deletions include/ck_tile/core/tensor/null_tile_window.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ struct null_tile_window

CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

CK_TILE_DEVICE void init_raw() {}

WindowLengths window_lengths_;
};

Expand Down
24 changes: 15 additions & 9 deletions include/ck_tile/core/tensor/tensor_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct tensor_view
{
}

CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }

CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
Expand Down Expand Up @@ -82,30 +84,34 @@ struct tensor_view
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE void
get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {}) const
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
const TensorCoord& coord,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
return buf_.template get_raw<X, oob_conditional_check>(
return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
dst,
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<pre_nop>{});
}

template <typename X,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t<DataType>* smem,
const TensorCoord& coord) const
CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements_raw(
remove_cvref_t<DataType>* smem, const TensorCoord& coord, bool_constant<pre_nop> = {}) const
{
return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
return buf_.template async_get_raw<X>(
smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
}

// X is vector of DataType.
Expand Down
Loading

0 comments on commit 6659340

Please sign in to comment.