diff --git a/python/codegen/mbb.py b/python/codegen/mbb.py index 541fbf11..f5487487 100644 --- a/python/codegen/mbb.py +++ b/python/codegen/mbb.py @@ -34,7 +34,9 @@ MC_INST_TYPE_GLOBAL_MEM = 3 MC_INST_TYPE_LEGACY_MACRO = 4 # like macro_c_clear_t. this is a hack MC_INST_TYPE_COMMENTS = 5 -MC_INST_TYPE_OTHER = 6 +MC_INST_TYPE_PREDEFINE_IF = 6 +MC_INST_TYPE_PREDEFINE_ENDIF = 7 +MC_INST_TYPE_OTHER = 8 def get_mc_inst_op(inst_str): istr = inst_str.strip() @@ -59,6 +61,10 @@ def mc_inst_is_global_mem(inst_op): return _check_inst_prefix(inst_op, ['global_', 'buffer_']) def mc_inst_is_legacy_macro(inst_op): return _check_inst_prefix(inst_op, ['.v_clear_nc']) +def mc_inst_is_predefine_if(inst_op): + return _check_inst_prefix(inst_op, ['.if']) +def mc_inst_is_predefine_endif(inst_op): + return _check_inst_prefix(inst_op, ['.endif']) def get_mc_inst_type(inst_str): ''' @@ -76,6 +82,10 @@ def get_mc_inst_type(inst_str): return MC_INST_TYPE_GLOBAL_MEM if mc_inst_is_legacy_macro(inst_op): return MC_INST_TYPE_LEGACY_MACRO + if mc_inst_is_predefine_if(inst_op): + return MC_INST_TYPE_PREDEFINE_IF + if mc_inst_is_predefine_endif(inst_op): + return MC_INST_TYPE_PREDEFINE_ENDIF return MC_INST_TYPE_OTHER class mc_inst_t(object): @@ -105,8 +115,19 @@ def create_mc_inst(inst_str): if mc_inst_is_legacy_macro(get_mc_inst_op(istr)): return mc_inst_t(inst_str) + def inst_in_directive_white_list(inst): + # TODO: with the .if .. .else, this should group into a single mbb + if istr[0] != '.': + return False + asm_directive_white_list = ['.if', '.ifdef', '.else', '.endif'] + for itm in asm_directive_white_list: + if inst.startswith(itm): + return True + return False + if istr[0] in (';', '/', '.', '\n'): # ignore comment, directive like .set, .macro - return None + if not inst_in_directive_white_list(istr): + return None # print(f'[XX] {istr[0]}, {inst_str}') return mc_inst_t(inst_str) @@ -172,10 +193,14 @@ def create_machine_basic_block(multi_line_inst_str, **option): class parse_mbb_list_t(object): STATE_NORMAL = 0 STATE_PARSING_MBB = 1 + STATE_PARSING_MBB_IN_PREDEFINE = 2 INST_MBB_START = ['v_cmpx'] INST_MBB_END = ['s_mov_b64 exec', 's_or_b64 exec'] + INST_MBB_START_PREDEFINE = ['.if'] + INST_MBB_END_PREDEFINE = ['.endif'] + def is_mbb_start_macro_c_clear(self, current_index, istrs_list): ''' special rule for macro_c_clear_t @@ -242,6 +267,22 @@ def is_mbb_end(self, istr): return True return False + def is_mbb_start_predefine(self, istr): + _istr = istr.strip() + _istr = re.sub(' +', ' ', _istr) # remove multiple space characters + for ms in self.INST_MBB_START_PREDEFINE: + if _istr.startswith(ms): + return True + return False + + def is_mbb_end_predefine(self, istr): + _istr = istr.strip() + _istr = re.sub(' +', ' ', _istr) # remove multiple space characters + for ms in self.INST_MBB_END_PREDEFINE: + if _istr.startswith(ms): + return True + return False + def parse(self, multi_line_inst_str, **option): def get_dict_with_default(dictionary, key, default_value): if key in dictionary: @@ -262,6 +303,18 @@ def match_group_mbb_by_end_of_inst_op(inst_op): return True return False assert False + + def match_group_mbb_by_end_of_inst_op_lookback(current_index, istrs_list): + if current_index <= 0: + return False + for prev_index in range(current_index - 1, -1, -1): + prev_istr = istrs_list[prev_index] + prev_mc_inst = create_mc_inst(prev_istr) + prev_inst_op = get_mc_inst_op(prev_istr) + if not prev_inst_op: + continue + return match_group_mbb_by_end_of_inst_op(prev_inst_op) + return False # nothing to search istrs =
multi_line_inst_str.split('\n') mbbs = list() @@ -295,15 +348,24 @@ def match_group_mbb_by_end_of_inst_op(inst_op): if group_mbb_by_end_of_inst_op: inst_op = get_mc_inst_op(istr) if state == self.STATE_NORMAL: - if match_group_mbb_by_end_of_inst_op(inst_op): - mbbs.append(machine_basic_block_t(copy.copy([mc_inst]))) + if self.is_mbb_start_predefine(istr): + mc_inst_buffer.append(mc_inst) + state = self.STATE_PARSING_MBB_IN_PREDEFINE + elif match_group_mbb_by_end_of_inst_op(inst_op): + mc_inst_buffer.append(mc_inst) + mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer))) + # for yy in mc_inst_buffer: + # print(f" +++inst:{yy()}") + # print(f" +++--------------------") mc_inst_buffer.clear() else: - #mc_inst_buffer.clear() mc_inst_buffer.append(mc_inst) state = self.STATE_PARSING_MBB - else: - if match_group_mbb_by_end_of_inst_op(inst_op): + elif state == self.STATE_PARSING_MBB: + if self.is_mbb_start_predefine(istr): + mc_inst_buffer.append(mc_inst) + state = self.STATE_PARSING_MBB_IN_PREDEFINE + elif match_group_mbb_by_end_of_inst_op(inst_op): mc_inst_buffer.append(mc_inst) mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer))) # print(f'xxxxx_ {mc_inst_buffer}, len:{len(mc_inst_buffer)}') @@ -311,12 +373,32 @@ def match_group_mbb_by_end_of_inst_op(inst_op): # print(f" inst:{yy()}") # machine_basic_block_t(copy.copy(mc_inst_buffer)).dump() state = self.STATE_NORMAL + #for yy in mc_inst_buffer: + # print(f" +++inst:{yy()}") + #print(f" +++--------------------") mc_inst_buffer.clear() else: mc_inst_buffer.append(mc_inst) + elif state == self.STATE_PARSING_MBB_IN_PREDEFINE: + if self.is_mbb_end_predefine(istr): + ''' + only switch back, but not cut the mc_inst_buffer into mbb here + ''' + mc_inst_buffer.append(mc_inst) + if match_group_mbb_by_end_of_inst_op_lookback(i, istrs): + mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer))) + #for yy in mc_inst_buffer: + # print(f" +++inst:{yy()}") + #print(f" +++--------------------") + mc_inst_buffer.clear() + state = self.STATE_NORMAL + else: + mc_inst_buffer.append(mc_inst) + else: + assert False else: if state == self.STATE_NORMAL: - if self.is_mbb_start(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs) \ + if self.is_mbb_start(istr) or self.is_mbb_start_predefine(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs) \ or self.is_mbb_start_bfe_and_cmpx_block(i, istrs): mc_inst_buffer.clear() mc_inst_buffer.append(mc_inst) @@ -331,7 +413,7 @@ def match_group_mbb_by_end_of_inst_op(inst_op): pass else: assert False, f'not support recursive start/end for now, with {i}:{istr}, {istrs}' - if self.is_mbb_end(istr): + if self.is_mbb_end(istr) or self.is_mbb_end_predefine(istr): mc_inst_buffer.append(mc_inst) mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer))) state = self.STATE_NORMAL @@ -347,9 +429,10 @@ def match_group_mbb_by_end_of_inst_op(inst_op): #assert len(mbbs) != 0, f"nonthing parsed from input inst: {multi_line_inst_str}" if len(mbbs) == 0: return list() # silently return empty list + #print('************************') #for y in mbbs: # y.dump() - #print('++++++++++++++++++++++++++++') + #print('************************') if dup_inst_per_mbb != "off": _dup_str = dup_inst_per_mbb.split(',') assert len(_dup_str) == 2 diff --git a/python/codegen/scheduler.py b/python/codegen/scheduler.py index e4d2c1f9..020e70e1 100644 --- a/python/codegen/scheduler.py +++ b/python/codegen/scheduler.py @@ -327,15 +327,23 @@ def emit_current_mbb1(c_mbb1): return self._get_deferred() if interleave_pattern == 
INTERLEAVE_PTN_1: + def check_share_mem_block(current_mbb): + if current_mbb.mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM: + return True + if current_mbb.length() >=2: + if current_mbb.mc_inst(-2).type() == MC_INST_TYPE_SHARE_MEM and \ + current_mbb.mc_inst(-1).type() == MC_INST_TYPE_PREDEFINE_ENDIF: + return True + return False mbb_0_mfma_cnt = 0 for m in mbb_0: if mbb_have_mfma(m): mbb_0_mfma_cnt += 1 - assert mbb_1[0].mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM + assert check_share_mem_block(mbb_1[0]) num_smem = 0 for i in range(len(mbb_1)): - if mbb_1[i].mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM: + if check_share_mem_block(mbb_1[i]): num_smem = num_smem + 1 else: pass @@ -353,7 +361,7 @@ def emit_current_mbb1(c_mbb1): break # print(f' --- inst:{mbb_1[m1_idx]()} === {m1_idx}/{len(mbb_1)}, {smem_per_interleave_cnt}/smem_per_interleave:{smem_per_interleave}') self._emit(self.call_mbb(mbb_1[m1_idx])) - if mbb_1[m1_idx].mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM: + if check_share_mem_block(mbb_1[m1_idx]): smem_per_interleave_cnt = smem_per_interleave_cnt + 1 m1_idx += 1 if smem_per_interleave_cnt >= smem_per_interleave: diff --git a/python/codegen_driver.py b/python/codegen_driver.py index 43a407a0..867018a1 100755 --- a/python/codegen_driver.py +++ b/python/codegen_driver.py @@ -109,6 +109,11 @@ def emit_global_macro(self): macro_c_clear_t(self.mc).emit() if self.mc.arch_config.use_dlops: self._emit_fma_macro() + if hasattr(self.kernel_list[0], 'use_bf16_1k_in_fp16'): + if self.kernel_list[0].use_bf16_1k_in_fp16(): + sym = self.kernel_list[0].get_predefine_for_bf16_1k_in_fp16() + dfv = self.kernel_list[0].get_predefine_for_bf16_1k_in_fp16_default_value() + inst_mfma_emit_macro_mfma_16f(self.mc, sym, dfv) def emit_global_macro_per_s_file(self, mc): # emit global macro, independent of tunable diff --git a/python/igemm/igemm_bwd_gtc_nhwc.py b/python/igemm/igemm_bwd_gtc_nhwc.py index a2977b5f..edce462a 100755 --- a/python/igemm/igemm_bwd_gtc_nhwc.py +++ b/python/igemm/igemm_bwd_gtc_nhwc.py @@ -33,6 +33,7 @@ IGEMM_BWD_GTC_NHWC_ACCVGPR_UNIFIED = True # used in gfx90a IGEMM_BWD_GTC_PACK_DUE_ITER_B16_LO_HI = True +IGEMM_BWD_GTC_NHWC_USE_BF16_1K_IN_FP16 = True # used in gfx90a def _find_non_1_index_in_list(list_object): result_list = list() @@ -129,7 +130,7 @@ def flatten(x): from functools import reduce return reduce(lambda a, b: a*b, x, 1) ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, - self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, self.tunable.precision) + self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, self.tunable.precision, bf16_1k_in_fp16 = self.use_bf16_1k_in_fp16()) self.xdlops_mapping = igemm_xdlops_mapping_t(self.mc, ctrl_xdlops_mapping) assert flatten(ctrl_xdlops_mapping.acc_c_per_thread_m()) % self.coalescing_store_groups == 0, \ f"coalescing store groups should be divided by agpr per thread in m direction {ctrl_xdlops_mapping.acc_c_per_thread_m()}" @@ -210,7 +211,19 @@ def get_vector_write_out(): self.vgpr = self.kernel_vgpr_t(mc, self) if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self.agpr = self.kernel_agpr_t(mc, self) - + + def use_bf16_1k_in_fp16(self): + if self.tunable.precision == 'fp16' and 
self.mc.arch_config.arch == AMDGPU_ARCH_GFX90A and IGEMM_BWD_GTC_NHWC_USE_BF16_1K_IN_FP16: + return True + else: + return False + + def get_predefine_for_bf16_1k_in_fp16(self): + return 'igemm_bwd_fp16_alt_impl' + + def get_predefine_for_bf16_1k_in_fp16_default_value(self): + return 1 + def name(self): return igemm_gtc_encode_kernel_name(self.tunable, self.mc.arch_config.arch) @@ -735,6 +748,12 @@ def __call__(self): v = self.outer.vgpr m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() with self._deferred_context(): + if self.outer.use_bf16_1k_in_fp16(): + m_packed_fp16_to_bf16 = macro_packed_fp16_to_bf16_t(self.mc, num_vgpr = self.outer.get_num_vgpr_global_load_a()) + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if {fp16_alt_impl_pds} == 1') + self._emit(m_packed_fp16_to_bf16(v.v_gld_a(), v.v_tmp(5))) + self._emit(f'.endif') self._emit(m_in_2d_shared_store(v.v_gld_a(), v.v_sst_a_os())) return self._get_deferred() @@ -752,7 +771,7 @@ def __call__(self): ta_nb0, ta_nb1, ta_e, ta_k, tb_e, tb_k, tb_c0, tb_c1 = self.outer.get_thread_lengths() m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() with self._deferred_context(): - self._emit(m_wei_2d_shared_store(v.v_gld_b(), v.v_sst_b_os(), *(v.v_pack_k_tmp(),) if self.outer.tunable.precision in ('fp16', 'bf16') and tb_k % 2 == 0 else ())) + self._emit(m_wei_2d_shared_store(v.v_gld_b(), v.v_sst_b_os(), *(v.v_pack_k_tmp(), v.v_tmp(4)) if self.outer.tunable.precision in ('fp16', 'bf16') and tb_k % 2 == 0 else ())) return self._get_deferred() class kernel_karg_t(mc_base_t): @@ -1443,13 +1462,15 @@ def get_macro_shared_store(self): out_sst_ctrl.stride_d1 = k_pack_src_mat * data_byte class macro_wei_sst_t(macro_base_t): - def __init__(self, mc): + def __init__(self, mc, outer): macro_base_t.__init__(self, mc, True) + self.outer = outer self.issue_cnt = 0 self.declare_arg("v_src") self.declare_arg("v_sst_os") if data_byte == 2 and tb_k % 2 == 0: self.declare_arg("v_pack_k_tmp") # need tb_k // 2 + self.declare_arg("v_tmp2") def name(self): return '' @@ -1472,7 +1493,16 @@ def expr(self): idx = i_k * num_tb_c + i_c k_r, k_p = i_k // k_pack_src_mat, i_k % k_pack_src_mat offset = k_r * stride_dk + i_c * stride_dc + k_p * data_byte - self._emit(ds_write(self.v_sst_os(), self.v_src(idx), offset)) + if self.outer.use_bf16_1k_in_fp16(): + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if {fp16_alt_impl_pds} == 1') + self._emit(f"v_cvt_f32_f16 v[{self.v_src(idx)}], v[{self.v_src(idx)}]") + self._emit(ds_write(self.v_sst_os(), self.v_src(idx), offset, 1)) + self._emit(f'.else') + self._emit(ds_write(self.v_sst_os(), self.v_src(idx), offset)) + self._emit(f'.endif') + else: + self._emit(ds_write(self.v_sst_os(), self.v_src(idx), offset)) self.issue_cnt = self.issue_cnt + ds_write.get_issues(offset) else: packed_k_dword = tb_k // 2 @@ -1482,9 +1512,21 @@ def expr(self): for i_pk in range(packed_k_dword): idx_0 = 2 * i_pk * dwords_per_c + i_c // 2 idx_1 = 2 * i_pk * dwords_per_c + i_c // 2 + dwords_per_c - op_sel = '' if i_c % 2 == 0 else ' op_sel:[1, 1]' - # print(f"i_pk:{i_pk}, i_c:{i_c}, idx_0:{idx_0}, idx_1:{idx_1}") - self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_pk)}], v[{self.v_src(idx_0)}], v[{self.v_src(idx_1)}]{op_sel}") + if self.outer.use_bf16_1k_in_fp16(): + src0_sel = '' if i_c % 2 == 0 else ' src0_sel:WORD_1' + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if 
{fp16_alt_impl_pds} == 1') + self._emit(f"v_cvt_f32_f16 v[{self.v_tmp2(0)}], v[{self.v_src(idx_0)}]{src0_sel}") + self._emit(f"v_cvt_f32_f16 v[{self.v_tmp2(1)}], v[{self.v_src(idx_1)}]{src0_sel}") + self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_pk)}], v[{self.v_tmp2(0)}], v[{self.v_tmp2(1)}] op_sel:[1, 1]") + self._emit(f'.else') + op_sel = '' if i_c % 2 == 0 else ' op_sel:[1, 1]' + self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_pk)}], v[{self.v_src(idx_0)}], v[{self.v_src(idx_1)}]{op_sel}") + self._emit(f'.endif') + else: + op_sel = '' if i_c % 2 == 0 else ' op_sel:[1, 1]' + # print(f"i_pk:{i_pk}, i_c:{i_c}, idx_0:{idx_0}, idx_1:{idx_1}") + self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_pk)}], v[{self.v_src(idx_0)}], v[{self.v_src(idx_1)}]{op_sel}") self._emit(ds_write(self.v_sst_os(), self.v_pack_k_tmp(), i_c * stride_dc)) self.issue_cnt = self.issue_cnt + ds_write.get_issues(i_c * stride_dc) @@ -1525,7 +1567,7 @@ def get_issues(self): inline = True if self.tunable.fma_interleave else False return macro_igemm_3d_shared_store_t(self.mc, out_sst_ctrl, inline) if not self.tunable.tensor_a_pass_through else None, \ - macro_wei_sst_t(self.mc) if not self.tunable.tensor_b_pass_through else None + macro_wei_sst_t(self.mc, self) if not self.tunable.tensor_b_pass_through else None def get_macro_move_slice_window(self): inline = True if self.tunable.fma_interleave else False @@ -2764,7 +2806,7 @@ def move_slice_window_acc(): self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, - self.tunable.precision) + self.tunable.precision, bf16_1k_in_fp16 = self.use_bf16_1k_in_fp16()) fctrl.cxm = ctrl_xdlops_mapping fctrl.unroll_k = self.tunable.gemm_k_per_block fctrl.label_prefix = self.name() diff --git a/python/igemm/igemm_fwd_gtc_nhwc.py b/python/igemm/igemm_fwd_gtc_nhwc.py index 3216710d..40b00b4f 100755 --- a/python/igemm/igemm_fwd_gtc_nhwc.py +++ b/python/igemm/igemm_fwd_gtc_nhwc.py @@ -32,6 +32,7 @@ IGEMM_FWD_GTC_NHWC_PACK_IN_FLAG = 0 # IGEMM_FWD_GTC_NHWC_P_INTERLEAVE_GLD = False # p tensor interleave IGEMM_FWD_GTC_NHWC_ACCVGPR_UNIFIED = True # used in gfx90a +IGEMM_FWD_GTC_NHWC_USE_BF16_1K_IN_FP16 = True # used in gfx90a def _find_non_1_index_in_list(list_object): result_list = list() @@ -102,7 +103,7 @@ def flatten(x): from functools import reduce return reduce(lambda a, b: a*b, x, 1) ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, - self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, self.tunable.precision) + self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, self.tunable.precision, bf16_1k_in_fp16 = self.use_bf16_1k_in_fp16()) self.xdlops_mapping = igemm_xdlops_mapping_t(self.mc, ctrl_xdlops_mapping) assert flatten(ctrl_xdlops_mapping.acc_c_per_thread_m()) % self.coalescing_store_groups == 0, \ f"coalescing store groups should be divided by agpr per thread in m direction {ctrl_xdlops_mapping.acc_c_per_thread_m()}" @@ -184,6 +185,18 @@ def get_vector_write_out(): if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self.agpr = self.kernel_agpr_t(mc, 
self) + def use_bf16_1k_in_fp16(self): + if self.tunable.precision == 'fp16' and self.mc.arch_config.arch == AMDGPU_ARCH_GFX90A and IGEMM_FWD_GTC_NHWC_USE_BF16_1K_IN_FP16: + return True + else: + return False + + def get_predefine_for_bf16_1k_in_fp16(self): + return 'igemm_fwd_fp16_alt_impl' + + def get_predefine_for_bf16_1k_in_fp16_default_value(self): + return 0 + def name(self): return igemm_gtc_encode_kernel_name(self.tunable, self.mc.arch_config.arch) @@ -646,6 +659,12 @@ def __call__(self): v = self.outer.vgpr m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() with self._deferred_context(): + if self.outer.use_bf16_1k_in_fp16(): + m_packed_fp16_to_bf16 = macro_packed_fp16_to_bf16_t(self.mc, num_vgpr = self.outer.get_num_vgpr_global_load_a()) + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if {fp16_alt_impl_pds} == 1') + self._emit(m_packed_fp16_to_bf16(v.v_gld_a(), v.v_tmp(5))) + self._emit(f'.endif') self._emit(m_in_2d_shared_store(v.v_gld_a(), v.v_sst_a_os())) return self._get_deferred() @@ -662,6 +681,12 @@ def __call__(self): v = self.outer.vgpr m_in_2d_shared_store, m_wei_2d_shared_store = self.outer.get_macro_shared_store() with self._deferred_context(): + if self.outer.use_bf16_1k_in_fp16(): + m_packed_fp16_to_bf16 = macro_packed_fp16_to_bf16_t(self.mc, num_vgpr = self.outer.get_num_vgpr_global_load_b()) + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if {fp16_alt_impl_pds} == 1') + self._emit(m_packed_fp16_to_bf16(v.v_gld_b(), v.v_tmp(5))) + self._emit(f'.endif') self._emit(m_wei_2d_shared_store(v.v_gld_b(), v.v_sst_b_os())) return self._get_deferred() @@ -2287,7 +2312,7 @@ def move_slice_window_acc(): self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, - self.tunable.precision) + self.tunable.precision, bf16_1k_in_fp16 = self.use_bf16_1k_in_fp16()) fctrl.cxm = ctrl_xdlops_mapping fctrl.unroll_k = self.tunable.gemm_k_per_block fctrl.label_prefix = self.name() @@ -2347,6 +2372,8 @@ def move_slice_window_acc(): fctrl.pass_through_b = self.tunable.tensor_b_pass_through fctrl.pass_through_a_v_pack = self.get_k_pack() fctrl.pass_through_b_v_pack = self.get_k_pack() + fctrl.pass_through_bf16_1k_in_fp16 = self.use_bf16_1k_in_fp16() + fctrl.pass_through_bf16_1k_in_fp16_predefine = self.get_predefine_for_bf16_1k_in_fp16() fctrl.pass_through_a_interleave_gld = 1 if self.tunable.tensor_a_pass_through_interleave_gld else 0 fctrl.pass_through_b_interleave_gld = 1 if self.tunable.tensor_b_pass_through_interleave_gld else 0 diff --git a/python/igemm/igemm_wrw_gtc_nhwc.py b/python/igemm/igemm_wrw_gtc_nhwc.py index 7ab8d41f..71bb0a14 100755 --- a/python/igemm/igemm_wrw_gtc_nhwc.py +++ b/python/igemm/igemm_wrw_gtc_nhwc.py @@ -33,6 +33,7 @@ IGEMM_WRW_GTC_N_SPLIT_FIRST = 1 IGEMM_WRW_GTC_NHWC_ACCVGPR_UNIFIED = True # used in gfx90a +IGEMM_WRW_GTC_NHWC_USE_BF16_1K_IN_FP16 = True # used in gfx90a def _find_non_1_index_in_list(list_object): result_list = list() @@ -114,7 +115,7 @@ def flatten(x): from functools import reduce return reduce(lambda a, b: a*b, x, 1) ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, - self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, 
self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, self.tunable.precision) + self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, self.tunable.precision, bf16_1k_in_fp16 = self.use_bf16_1k_in_fp16()) self.xdlops_mapping = igemm_xdlops_mapping_t(self.mc, ctrl_xdlops_mapping) assert flatten(ctrl_xdlops_mapping.acc_c_per_thread_m()) % self.coalescing_store_groups == 0, \ f"coalescing store groups should be divided by agpr per thread in m direction {ctrl_xdlops_mapping.acc_c_per_thread_m()}" @@ -175,6 +176,17 @@ def get_vector_write_out(): if self.tunable.fma_type == IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: self.agpr = self.kernel_agpr_t(mc, self) + def use_bf16_1k_in_fp16(self): + if self.tunable.precision == 'fp16' and self.mc.arch_config.arch == AMDGPU_ARCH_GFX90A and IGEMM_WRW_GTC_NHWC_USE_BF16_1K_IN_FP16: + return True + else: + return False + + def get_predefine_for_bf16_1k_in_fp16(self): + return 'igemm_wrw_fp16_alt_impl' + + def get_predefine_for_bf16_1k_in_fp16_default_value(self): + return 1 def name(self): return igemm_gtc_encode_kernel_name(self.tunable, self.mc.arch_config.arch) @@ -377,9 +389,16 @@ def __call__(self): s = self.outer.sgpr v = self.outer.vgpr _, m_in_2d_shared_store = self.outer.get_macro_shared_store() + ta_k, ta_n, tb_n, tb_c = self.outer.get_thread_lengths() with self._deferred_context(): + if self.outer.use_bf16_1k_in_fp16() and (self.outer.tunable.precision == 'fp16' and ta_n == 1): + m_packed_fp16_to_bf16 = macro_packed_fp16_to_bf16_t(self.mc, num_vgpr = self.outer.get_num_vgpr_global_load_b()) + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if {fp16_alt_impl_pds} == 1') + self._emit(m_packed_fp16_to_bf16(v.v_gld_b(), v.v_tmp(5))) + self._emit(f'.endif') need_swizzle = self.outer.tunable.precision in ('fp16', 'bf16') and self.outer.tunable.tensor_b_thread_lengths[1] > 1 - self._emit(m_in_2d_shared_store(v.v_gld_b(), v.v_sst_b_os(), *(v.v_tmp(),) if need_swizzle else ())) + self._emit(m_in_2d_shared_store(v.v_gld_b(), v.v_sst_b_os(), *(v.v_tmp(),v.v_tmp(6)) if need_swizzle else ())) return self._get_deferred() class shared_store_out_t(mc_base_t): @@ -394,9 +413,16 @@ def __call__(self): s = self.outer.sgpr v = self.outer.vgpr m_out_2d_shared_store, _ = self.outer.get_macro_shared_store() + ta_k, ta_n, tb_n, tb_c = self.outer.get_thread_lengths() with self._deferred_context(): + if self.outer.use_bf16_1k_in_fp16() and (self.outer.tunable.precision == 'fp16' and ta_n == 1): + m_packed_fp16_to_bf16 = macro_packed_fp16_to_bf16_t(self.mc, num_vgpr = self.outer.get_num_vgpr_global_load_a()) + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if {fp16_alt_impl_pds} == 1') + self._emit(m_packed_fp16_to_bf16(v.v_gld_a(), v.v_tmp(5))) + self._emit(f'.endif') need_swizzle = self.outer.tunable.precision in ('fp16', 'bf16') and self.outer.tunable.tensor_b_thread_lengths[1] > 1 - self._emit(m_out_2d_shared_store(v.v_gld_a(), v.v_sst_a_os(), *(v.v_tmp(),) if need_swizzle else ())) + self._emit(m_out_2d_shared_store(v.v_gld_a(), v.v_sst_a_os(), *(v.v_tmp(),v.v_tmp(6)) if need_swizzle else ())) return self._get_deferred() class kernel_karg_t(mc_base_t): @@ -810,14 +836,16 @@ def get_macro_shared_store(self): vector_dp_b = length_dp_b class macro_swizzle_sst_t(macro_base_t): - def __init__(self, mc, t_mn): + def __init__(self, mc, t_mn, outer): macro_base_t.__init__(self, 
mc, True) self.issue_cnt = 0 self.t_mn = t_mn + self.outer = outer self.declare_arg("v_src") self.declare_arg("v_sst_os") if data_byte == 2: self.declare_arg("v_pack_k_tmp") # need tb_k // 2 + self.declare_arg("v_tmp2") def name(self): return '' @@ -844,9 +872,21 @@ def expr(self): for i_pk in range(packed_gemmk_dword): idx_0 = 2 * i_pk * dwords_per_mn + (i_gemmk * num_ds_write_pack + i_ds_write_pack) // 2 idx_1 = 2 * i_pk * dwords_per_mn + (i_gemmk * num_ds_write_pack + i_ds_write_pack) // 2 + dwords_per_mn - op_sel = '' if (i_gemmk * num_ds_write_pack + i_ds_write_pack) % 2 == 0 else ' op_sel:[1, 1]' - # print(f"i_pk:{i_pk}, i_c:{i_c}, idx_0:{idx_0}, idx_1:{idx_1}") - self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_ds_write_pack * 2 + i_pk)}], v[{self.v_src(idx_0)}], v[{self.v_src(idx_1)}]{op_sel}") + if self.outer.use_bf16_1k_in_fp16(): + src0_sel = '' if (i_gemmk * num_ds_write_pack + i_ds_write_pack) % 2 == 0 else ' src0_sel:WORD_1' + fp16_alt_impl_pds = self.outer.get_predefine_for_bf16_1k_in_fp16() + self._emit(f'.if {fp16_alt_impl_pds} == 1') + self._emit(f"v_cvt_f32_f16 v[{self.v_tmp2(0)}], v[{self.v_src(idx_0)}]{src0_sel}") + self._emit(f"v_cvt_f32_f16 v[{self.v_tmp2(1)}], v[{self.v_src(idx_1)}]{src0_sel}") + self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_ds_write_pack * 2 + i_pk)}], v[{self.v_tmp2(0)}], v[{self.v_tmp2(1)}] op_sel:[1, 1]") + self._emit(f'.else') + op_sel = '' if (i_gemmk * num_ds_write_pack + i_ds_write_pack) % 2 == 0 else ' op_sel:[1, 1]' + self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_ds_write_pack * 2 + i_pk)}], v[{self.v_src(idx_0)}], v[{self.v_src(idx_1)}]{op_sel}") + self._emit(f'.endif') + else: + op_sel = '' if (i_gemmk * num_ds_write_pack + i_ds_write_pack) % 2 == 0 else ' op_sel:[1, 1]' + # print(f"i_pk:{i_pk}, i_c:{i_c}, idx_0:{idx_0}, idx_1:{idx_1}") + self._emit(f"v_pack_b32_f16 v[{self.v_pack_k_tmp(i_ds_write_pack * 2 + i_pk)}], v[{self.v_src(idx_0)}], v[{self.v_src(idx_1)}]{op_sel}") self._emit(ds_write(self.v_sst_os(), self.v_pack_k_tmp(), i_gemmk * stride_d_mn)) self.issue_cnt = self.issue_cnt + ds_write.get_issues(i_gemmk * stride_d_mn) @@ -897,8 +937,8 @@ def get_issues(self): return macro_igemm_3d_shared_store_t(self.mc, out_sst_ctrl, inline) if not self.tunable.tensor_a_pass_through else None, \ macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline) if not self.tunable.tensor_b_pass_through else None else: - return macro_swizzle_sst_t(self.mc, ta_k) if not self.tunable.tensor_a_pass_through else None, \ - macro_swizzle_sst_t(self.mc, tb_c) if not self.tunable.tensor_a_pass_through else None + return macro_swizzle_sst_t(self.mc, ta_k, self) if not self.tunable.tensor_a_pass_through else None, \ + macro_swizzle_sst_t(self.mc, tb_c, self) if not self.tunable.tensor_a_pass_through else None def get_macro_in_out_update_os(self): inline = True if self.tunable.fma_interleave else False @@ -1579,7 +1619,7 @@ def move_slice_window_a(): ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block,self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, - self.tunable.precision) + self.tunable.precision, bf16_1k_in_fp16 = self.use_bf16_1k_in_fp16()) fctrl.cxm = ctrl_xdlops_mapping fctrl.unroll_k = self.tunable.gemm_k_per_block fctrl.label_prefix = self.name() diff --git a/python/operations/mfma.py 
b/python/operations/mfma.py index 95fe3846..f5b05878 100644 --- a/python/operations/mfma.py +++ b/python/operations/mfma.py @@ -25,6 +25,7 @@ ################################################################################ # pylint: disable=maybe-no-member from ..codegen import * +import copy def inst_mfma_data_type_to_string(data_type): if data_type == AMDGPU_PRECISION_FP32: @@ -58,6 +59,8 @@ def __init__(self, m, n, k, data_type, cycle, num_v_a, num_v_b, num_a_c, num_blo #assert arch_config.arch == AMDGPU_ARCH_GFX908 and arch_config.use_xdlops def name(self): + if 'name' in self.options and self.options['name'] != None: + return self.options['name'] def src_datatype_string(data_type_string): if data_type_string == 'fp32': return 'f32' @@ -71,7 +74,7 @@ def src_datatype_string(data_type_string): mfma_acc_type = 'i32' if self.data_type == AMDGPU_PRECISION_INT8 else 'f32' # TODO: int8 mfma accumulate type is i32 mfma_trait = f'{self.m}x{self.n}x{self.k}' + src_datatype_string(inst_mfma_data_type_to_string(self.data_type)) mfma_inst = f'v_mfma_{mfma_acc_type}_{mfma_trait}' - if 'bf16_1k' in self.options and self.options['bf16_1k']: + if 'bf16_1k' in self.options and self.options['bf16_1k'] and self.data_type == AMDGPU_PRECISION_BF16: mfma_inst += '_1k' return mfma_inst @@ -121,6 +124,36 @@ def get_nop_count_mfma_acc_raw(self): v_mfma_f32_32x32x4bf16_1k = inst_mfma_t(32, 32, 4, AMDGPU_PRECISION_BF16, 64, 2, 2, 32, 2 , bf16_1k=True) v_mfma_f32_32x32x8bf16_1k = inst_mfma_t(32, 32, 8, AMDGPU_PRECISION_BF16, 64, 2, 2, 16, 1 , bf16_1k=True) +v_mfma_f32_4x4x4_16f_m = inst_mfma_t(4, 4, 4, AMDGPU_PRECISION_BF16, 8, 2, 2, 4, 16, bf16_1k=True, name='v_mfma_f32_4x4x4_16f_m') +v_mfma_f32_16x16x4_16f_m = inst_mfma_t(16, 16, 4, AMDGPU_PRECISION_BF16, 32, 2, 2, 16, 4 , bf16_1k=True, name='v_mfma_f32_16x16x4_16f_m') +v_mfma_f32_16x16x16_16f_m = inst_mfma_t(16, 16, 16, AMDGPU_PRECISION_BF16, 32, 2, 2, 4, 1 , bf16_1k=True, name='v_mfma_f32_16x16x16_16f_m') +v_mfma_f32_32x32x4_16f_m = inst_mfma_t(32, 32, 4, AMDGPU_PRECISION_BF16, 64, 2, 2, 32, 2 , bf16_1k=True, name='v_mfma_f32_32x32x4_16f_m') +v_mfma_f32_32x32x8_16f_m = inst_mfma_t(32, 32, 8, AMDGPU_PRECISION_BF16, 64, 2, 2, 16, 1 , bf16_1k=True, name='v_mfma_f32_32x32x8_16f_m') + +def inst_mfma_emit_macro_mfma_16f(mc, predefined_symbol_bf16_enable, default_value): + mc.emit(f'.ifndef {predefined_symbol_bf16_enable}') + mc.emit(f'.set {predefined_symbol_bf16_enable}, {default_value}') + mc.emit(f'.endif') + mc.emit_empty_line() + + the_list = [v_mfma_f32_4x4x4_16f_m, v_mfma_f32_16x16x4_16f_m, v_mfma_f32_16x16x16_16f_m, v_mfma_f32_32x32x4_16f_m, v_mfma_f32_32x32x8_16f_m] + + for inst in the_list: + inst_16f = copy.deepcopy(inst) + inst_16f.options['name'] = None + # print(f'{inst.options}') + inst_16f.data_type = AMDGPU_PRECISION_BF16 + macro_name = inst.options['name'] + mc.emit(f'.macro {macro_name} d, a, b, c') + mc.emit(f'.if {predefined_symbol_bf16_enable} == 1') + mc.emit(f' {inst_16f.name()} \\d, \\a, \\b, \\c') + mc.emit(f'.else') + inst_16f.data_type = AMDGPU_PRECISION_FP16 + mc.emit(f' {inst_16f.name()} \\d, \\a, \\b, \\c') + mc.emit(f'.endif') + mc.emit(f'.endm') + mc.emit_empty_line() + # class inst_composed_mfma_t(object): # ''' # handy class to issue several mfma to form a wave wise mxn diff --git a/python/operations/mfma_main_loop.py b/python/operations/mfma_main_loop.py index 98b99e19..18079210 100755 --- a/python/operations/mfma_main_loop.py +++ b/python/operations/mfma_main_loop.py @@ -87,6 +87,8 @@ def __init__(self): 
self.pass_through_b_v_pack = 1 self.pass_through_a_interleave_gld = 1 self.pass_through_b_interleave_gld = 1 + self.pass_through_bf16_1k_in_fp16 = False # the pass through side is indeed bf16 1k + self.pass_through_bf16_1k_in_fp16_predefine = None # predefine symbol for .if....else self.opt_1st_sld = True # optimize 1st ds_read class mfma_main_loop_t(mc_base_t): @@ -469,7 +471,16 @@ def do_sld_q(i_v, i_r): if not p_interleave_gld and v_gld_p_gpf: # move buffer for i_pnum in range(v_gld_p_num): - self._emit(f"v_mov_b32 v[{v_gld_p(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}]") + if ctrl.pass_through_bf16_1k_in_fp16: + self._emit(f".if {ctrl.pass_through_bf16_1k_in_fp16_predefine} == 1") + self._emit(f"v_cvt_f32_f16 v[{v_gld_p(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}]") + self._emit(f"v_cvt_f32_f16 v[{v_gld_p_gpf(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}] src0_sel:WORD_1") + self._emit(f"v_pack_b32_f16 v[{v_gld_p(i_pnum)}], v[{v_gld_p(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}] op_sel:[1,1]") + self._emit(f".else") + self._emit(f"v_mov_b32 v[{v_gld_p(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}]") + self._emit(f".endif") + else: + self._emit(f"v_mov_b32 v[{v_gld_p(i_pnum)}], v[{v_gld_p_gpf(i_pnum)}]") for i_v in range(v_pack_p_per_kpt): self._emit(mfma_step_pxq_vk(i_k, i_rp, i_rq, i_v, i_local_buffer_q)) diff --git a/python/operations/utility.py b/python/operations/utility.py index 074e08e5..0dcc7ffa 100644 --- a/python/operations/utility.py +++ b/python/operations/utility.py @@ -314,6 +314,23 @@ def __call__(self, step = 0, alignment = 0): def get(self): return self.cnt +class macro_packed_fp16_to_bf16_t(macro_base_t): + def __init__(self, mc, **options): + macro_base_t.__init__(self, mc, True) + self.options = options + self.declare_arg("v_packed_f16") + self.declare_arg("v_tmp") + assert 'num_vgpr' in options + + def name(self): + return '.v_packed_fp16_to_bf16' + + def expr(self): + num_vgpr = self.options["num_vgpr"] + for i in range(num_vgpr): + self._emit(f"v_cvt_f32_f16 v[{self.v_tmp()}], v[{self.v_packed_f16(i)}]") + self._emit(f"v_cvt_f32_f16 v[{self.v_packed_f16(i)}], v[{self.v_packed_f16(i)}] src0_sel:WORD_1") + self._emit(f"v_pack_b32_f16 v[{self.v_packed_f16(i)}], v[{self.v_tmp()}], v[{self.v_packed_f16(i)}] op_sel:[1,1]") def utility_list_to_string(arr): assert type(arr) is list diff --git a/python/operations/xdlops_mapping.py b/python/operations/xdlops_mapping.py index fb37b571..4bd91497 100755 --- a/python/operations/xdlops_mapping.py +++ b/python/operations/xdlops_mapping.py @@ -470,6 +470,25 @@ def fp16_mfma_to_bf16_1k(fp16_mfma): item.wave_repeat_m, item.wave_repeat_n, item.wave_step_m, item.wave_step_n, fp16_mfma_to_bf16_1k(item.inst_mfma)) for item in ctrl_xdlops_mapping_fp16 ] +def fp16_mfma_to_16f(fp16_mfma): + if fp16_mfma.name() == 'v_mfma_f32_4x4x4f16': + return v_mfma_f32_4x4x4_16f_m + if fp16_mfma.name() == 'v_mfma_f32_16x16x4f16': + return v_mfma_f32_16x16x4_16f_m + if fp16_mfma.name() == 'v_mfma_f32_16x16x16f16': + return v_mfma_f32_16x16x16_16f_m + if fp16_mfma.name() == 'v_mfma_f32_32x32x4f16': + return v_mfma_f32_32x32x4_16f_m + if fp16_mfma.name() == 'v_mfma_f32_32x32x8f16': + return v_mfma_f32_32x32x8_16f_m + assert False, 'no such fp16 inst ' + fp16_mfma.name() + return None + +ctrl_xdlops_mapping_16f = [ctrl_xdlops_mapping_t(item.macro_tile_m, item.macro_tile_n, + item.wave_tile_m, item.wave_tile_n, item.wave_tile_k, item.waves, + item.wave_repeat_m, item.wave_repeat_n, item.wave_step_m, item.wave_step_n, + fp16_mfma_to_16f(item.inst_mfma)) for item in ctrl_xdlops_mapping_fp16 ] + 
ctrl_xdlops_mapping_int8 = [ ctrl_xdlops_mapping_t( 256, 256, 64, 32, 4, 4, 2, 2, 1, 2, v_mfma_i32_32x32x4i8), ctrl_xdlops_mapping_t( 256, 256, 32, 32, 8, 4, 2, 2, 2, 2, v_mfma_i32_32x32x8i8), @@ -544,8 +563,8 @@ def get_ctrl_xdlops_mapping_from_wave_tile(macro_tile_m, macro_tile_n, wave_tile if precision == AMDGPU_PRECISION_FP32: ctrl_xdlops_mapping = ctrl_xdlops_mapping_fp32 elif precision == AMDGPU_PRECISION_FP16: - if 'bf16_1k' in options and options['bf16_1k']: - ctrl_xdlops_mapping = ctrl_xdlops_mapping_bf16_1k + if 'bf16_1k_in_fp16' in options and options['bf16_1k_in_fp16']: + ctrl_xdlops_mapping = ctrl_xdlops_mapping_16f else: ctrl_xdlops_mapping = ctrl_xdlops_mapping_fp16 elif precision == AMDGPU_PRECISION_INT8: @@ -576,7 +595,7 @@ def set_ctrl_xdlops_mapping_accvgpr_unified(accvgpr_unified): if set_ctrl_xdlops_mapping_accvgpr_unified.cached_accvgpr_unified == accvgpr_unified: return set_ctrl_xdlops_mapping_accvgpr_unified.cached_accvgpr_unified = accvgpr_unified - for ctrl in (ctrl_xdlops_mapping_fp32, ctrl_xdlops_mapping_fp16, ctrl_xdlops_mapping_int8, ctrl_xdlops_mapping_bf16_1k): + for ctrl in (ctrl_xdlops_mapping_fp32, ctrl_xdlops_mapping_fp16, ctrl_xdlops_mapping_int8, ctrl_xdlops_mapping_bf16_1k, ctrl_xdlops_mapping_16f): for x in ctrl: x.inst_mfma.accvgpr_unified = accvgpr_unified
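
For reference, the bf16_1k-in-fp16 switch added above works in three layers. First, codegen_driver.py calls inst_mfma_emit_macro_mfma_16f (operations/mfma.py) from emit_global_macro() when the kernel reports use_bf16_1k_in_fp16(): for each *_16f_m pseudo-instruction it defines an assembler macro that expands to the bf16_1k MFMA when the predefine symbol is 1 and to the plain fp16 MFMA otherwise, after giving the symbol its kernel-specific default (1 for bwd and wrw, 0 for fwd). For the 32x32x8 tile and the bwd symbol the emitted text is roughly:

.ifndef igemm_bwd_fp16_alt_impl
.set igemm_bwd_fp16_alt_impl, 1
.endif

.macro v_mfma_f32_32x32x8_16f_m d, a, b, c
.if igemm_bwd_fp16_alt_impl == 1
    v_mfma_f32_32x32x8bf16_1k \d, \a, \b, \c
.else
    v_mfma_f32_32x32x8f16 \d, \a, \b, \c
.endif
.endm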
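
Second, the data fed to those macros has to be bf16 whenever the symbol is 1. The new macro_packed_fp16_to_bf16_t (operations/utility.py) and the .if-guarded branches in the shared-store macros do that conversion: each packed-fp16 dword is widened with v_cvt_f32_f16 (low half, then high half via src0_sel:WORD_1), and v_pack_b32_f16 op_sel:[1,1] keeps only the upper 16 bits of the two f32 results, i.e. their bf16 truncations, repacked into one dword. With the macro's argument names standing in for the allocated VGPRs, the first loop iteration (i = 0) emits:

v_cvt_f32_f16 v[v_tmp], v[v_packed_f16]
v_cvt_f32_f16 v[v_packed_f16], v[v_packed_f16] src0_sel:WORD_1
v_pack_b32_f16 v[v_packed_f16], v[v_tmp], v[v_packed_f16] op_sel:[1,1]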
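
Finally, the mbb.py changes exist so that these .if/.endif-guarded sequences survive basic-block parsing: directives on the small white list are no longer dropped, and the new STATE_PARSING_MBB_IN_PREDEFINE state keeps everything between an .if and its matching .endif inside one machine_basic_block_t so the scheduler can still interleave it with MFMA. Below is a minimal standalone sketch of that grouping idea, not the repository's parse_mbb_list_t; the function name and the sample lines are made up for illustration, and it omits the cut decision that match_group_mbb_by_end_of_inst_op_lookback makes at the .endif.

# Minimal sketch: keep everything between ".if" and its matching ".endif" in one
# block instead of discarding the directives.
def group_predefine_blocks(lines):
    blocks, buf, in_predefine = [], [], False
    for line in lines:
        s = line.strip()
        if s.startswith('.if'):            # also matches .ifdef, as in the white list
            in_predefine = True
            buf.append(line)
        elif s.startswith('.endif'):
            buf.append(line)
            blocks.append(buf)             # close the guarded block
            buf, in_predefine = [], False
        elif in_predefine:
            buf.append(line)               # guarded instruction stays in the block
        else:
            blocks.append([line])          # ordinary instruction: its own block
    return blocks

asm = ['.if igemm_bwd_fp16_alt_impl == 1',
       '    v_cvt_f32_f16 v[10], v[10]',
       '.endif',
       'ds_write_b32 v[2], v[10]']
print(group_predefine_blocks(asm))
# [['.if igemm_bwd_fp16_alt_impl == 1', '    v_cvt_f32_f16 v[10], v[10]', '.endif'],
#  ['ds_write_b32 v[2], v[10]']]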