Skip to content

Commit

Permalink
Merge pull request #121 from ROCmSoftwarePlatform/gfx90a_alt_fp16_impl
Browse files Browse the repository at this point in the history
gfx90a alt fp16 impl
  • Loading branch information
carlushuang committed Oct 15, 2021
2 parents 26ee7b0 + 405351d commit 92dd200
Show file tree
Hide file tree
Showing 10 changed files with 325 additions and 40 deletions.
103 changes: 93 additions & 10 deletions python/codegen/mbb.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
MC_INST_TYPE_GLOBAL_MEM = 3
MC_INST_TYPE_LEGACY_MACRO = 4 # like macro_c_clear_t. this is a hack
MC_INST_TYPE_COMMENTS = 5
MC_INST_TYPE_OTHER = 6
MC_INST_TYPE_PREDEFINE_IF = 6
MC_INST_TYPE_PREDEFINE_ENDIF = 7
MC_INST_TYPE_OTHER = 8

def get_mc_inst_op(inst_str):
istr = inst_str.strip()
Expand All @@ -59,6 +61,10 @@ def mc_inst_is_global_mem(inst_op):
return _check_inst_prefix(inst_op, ['global_', 'buffer_'])
def mc_inst_is_legacy_macro(inst_op):
return _check_inst_prefix(inst_op, ['.v_clear_nc'])
def mc_inst_is_predefine_if(inst_op):
return _check_inst_prefix(inst_op, ['.if'])
def mc_inst_is_predefine_endif(inst_op):
return _check_inst_prefix(inst_op, ['.endif'])

def get_mc_inst_type(inst_str):
'''
Expand All @@ -76,6 +82,10 @@ def get_mc_inst_type(inst_str):
return MC_INST_TYPE_GLOBAL_MEM
if mc_inst_is_legacy_macro(inst_op):
return MC_INST_TYPE_LEGACY_MACRO
if mc_inst_is_predefine_if(inst_op):
return MC_INST_TYPE_PREDEFINE_IF
if mc_inst_is_predefine_endif(inst_op):
return MC_INST_TYPE_PREDEFINE_ENDIF
return MC_INST_TYPE_OTHER

class mc_inst_t(object):
Expand Down Expand Up @@ -105,8 +115,19 @@ def create_mc_inst(inst_str):
if mc_inst_is_legacy_macro(get_mc_inst_op(istr)):
return mc_inst_t(inst_str)

def inst_in_directive_white_list(inst):
# TODO: with the .if .. .else, this should group into a single mbb
if istr[0] != '.':
return False
asm_directive_white_list = ['.if', '.ifdef', '.else', '.endif']
for itm in asm_directive_white_list:
if inst.startswith(itm):
return True
return False

if istr[0] in (';', '/', '.', '\n'): # ignore comment, directive like .set, .macro
return None
if not inst_in_directive_white_list(istr):
return None
# print(f'[XX] {istr[0]}, {inst_str}')
return mc_inst_t(inst_str)

Expand Down Expand Up @@ -172,10 +193,14 @@ def create_machine_basic_block(multi_line_inst_str, **option):
class parse_mbb_list_t(object):
STATE_NORMAL = 0
STATE_PARSING_MBB = 1
STATE_PARSING_MBB_IN_PREDEFINE = 2

INST_MBB_START = ['v_cmpx']
INST_MBB_END = ['s_mov_b64 exec', 's_or_b64 exec']

INST_MBB_START_PREDEFINE = ['.if']
INST_MBB_END_PREDEFINE = ['.endif']

def is_mbb_start_macro_c_clear(self, current_index, istrs_list):
'''
special rule for macro_c_clear_t
Expand Down Expand Up @@ -242,6 +267,22 @@ def is_mbb_end(self, istr):
return True
return False

def is_mbb_start_predefine(self, istr):
_istr = istr.strip()
_istr = re.sub(' +', ' ', _istr) # remove multiple space character
for ms in self.INST_MBB_START_PREDEFINE:
if _istr.startswith(ms):
return True
return False

def is_mbb_end_predefine(self, istr):
_istr = istr.strip()
_istr = re.sub(' +', ' ', _istr) # remove multiple space character
for ms in self.INST_MBB_END_PREDEFINE:
if _istr.startswith(ms):
return True
return False

def parse(self, multi_line_inst_str, **option):
def get_dict_with_default(dictionary, key, default_value):
if key in dictionary:
Expand All @@ -262,6 +303,18 @@ def match_group_mbb_by_end_of_inst_op(inst_op):
return True
return False
assert False

def match_group_mbb_by_end_of_inst_op_lookback(current_index, istrs_list):
if current_index <= 0:
return False
for prev_index in range(current_index - 1, -1, -1):
prev_istr = istrs_list[prev_index]
prev_mc_inst = create_mc_inst(prev_istr)
prev_inst_op = get_mc_inst_op(prev_istr)
if not prev_inst_op:
continue
return match_group_mbb_by_end_of_inst_op(prev_inst_op)
return False # nonthing to search

istrs = multi_line_inst_str.split('\n')
mbbs = list()
Expand Down Expand Up @@ -295,28 +348,57 @@ def match_group_mbb_by_end_of_inst_op(inst_op):
if group_mbb_by_end_of_inst_op:
inst_op = get_mc_inst_op(istr)
if state == self.STATE_NORMAL:
if match_group_mbb_by_end_of_inst_op(inst_op):
mbbs.append(machine_basic_block_t(copy.copy([mc_inst])))
if self.is_mbb_start_predefine(istr):
mc_inst_buffer.append(mc_inst)
state = self.STATE_PARSING_MBB_IN_PREDEFINE
elif match_group_mbb_by_end_of_inst_op(inst_op):
mc_inst_buffer.append(mc_inst)
mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer)))
# for yy in mc_inst_buffer:
# print(f" +++inst:{yy()}")
# print(f" +++--------------------")
mc_inst_buffer.clear()
else:
#mc_inst_buffer.clear()
mc_inst_buffer.append(mc_inst)
state = self.STATE_PARSING_MBB
else:
if match_group_mbb_by_end_of_inst_op(inst_op):
elif state == self.STATE_PARSING_MBB:
if self.is_mbb_start_predefine(istr):
mc_inst_buffer.append(mc_inst)
state = self.STATE_PARSING_MBB_IN_PREDEFINE
elif match_group_mbb_by_end_of_inst_op(inst_op):
mc_inst_buffer.append(mc_inst)
mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer)))
# print(f'xxxxx_ {mc_inst_buffer}, len:{len(mc_inst_buffer)}')
# for yy in mc_inst_buffer:
# print(f" inst:{yy()}")
# machine_basic_block_t(copy.copy(mc_inst_buffer)).dump()
state = self.STATE_NORMAL
#for yy in mc_inst_buffer:
# print(f" +++inst:{yy()}")
#print(f" +++--------------------")
mc_inst_buffer.clear()
else:
mc_inst_buffer.append(mc_inst)
elif state == self.STATE_PARSING_MBB_IN_PREDEFINE:
if self.is_mbb_end_predefine(istr):
'''
only switch back, but not cut the mc_inst_buffer into mbb here
'''
mc_inst_buffer.append(mc_inst)
if match_group_mbb_by_end_of_inst_op_lookback(i, istrs):
mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer)))
#for yy in mc_inst_buffer:
# print(f" +++inst:{yy()}")
#print(f" +++--------------------")
mc_inst_buffer.clear()
state = self.STATE_NORMAL
else:
mc_inst_buffer.append(mc_inst)
else:
assert False
else:
if state == self.STATE_NORMAL:
if self.is_mbb_start(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs) \
if self.is_mbb_start(istr) or self.is_mbb_start_predefine(istr) or self.is_mbb_start_cmp_and_exec_block(i, istrs) \
or self.is_mbb_start_bfe_and_cmpx_block(i, istrs):
mc_inst_buffer.clear()
mc_inst_buffer.append(mc_inst)
Expand All @@ -331,7 +413,7 @@ def match_group_mbb_by_end_of_inst_op(inst_op):
pass
else:
assert False, f'not support recursive start/end for now, with {i}:{istr}, {istrs}'
if self.is_mbb_end(istr):
if self.is_mbb_end(istr) or self.is_mbb_end_predefine(istr):
mc_inst_buffer.append(mc_inst)
mbbs.append(machine_basic_block_t(copy.copy(mc_inst_buffer)))
state = self.STATE_NORMAL
Expand All @@ -347,9 +429,10 @@ def match_group_mbb_by_end_of_inst_op(inst_op):
#assert len(mbbs) != 0, f"nonthing parsed from input inst: {multi_line_inst_str}"
if len(mbbs) == 0:
return list() # silently return empty list
#print('************************')
#for y in mbbs:
# y.dump()
#print('++++++++++++++++++++++++++++')
#print('************************')
if dup_inst_per_mbb != "off":
_dup_str = dup_inst_per_mbb.split(',')
assert len(_dup_str) == 2
Expand Down
14 changes: 11 additions & 3 deletions python/codegen/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,23 @@ def emit_current_mbb1(c_mbb1):
return self._get_deferred()

if interleave_pattern == INTERLEAVE_PTN_1:
def check_share_mem_block(current_mbb):
if current_mbb.mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM:
return True
if current_mbb.length() >=2:
if current_mbb.mc_inst(-2).type() == MC_INST_TYPE_SHARE_MEM and \
current_mbb.mc_inst(-1).type() == MC_INST_TYPE_PREDEFINE_ENDIF:
return True
return False
mbb_0_mfma_cnt = 0
for m in mbb_0:
if mbb_have_mfma(m):
mbb_0_mfma_cnt += 1

assert mbb_1[0].mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM
assert check_share_mem_block(mbb_1[0])
num_smem = 0
for i in range(len(mbb_1)):
if mbb_1[i].mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM:
if check_share_mem_block(mbb_1[i]):
num_smem = num_smem + 1
else:
pass
Expand All @@ -353,7 +361,7 @@ def emit_current_mbb1(c_mbb1):
break
# print(f' --- inst:{mbb_1[m1_idx]()} === {m1_idx}/{len(mbb_1)}, {smem_per_interleave_cnt}/smem_per_interleave:{smem_per_interleave}')
self._emit(self.call_mbb(mbb_1[m1_idx]))
if mbb_1[m1_idx].mc_inst(-1).type() == MC_INST_TYPE_SHARE_MEM:
if check_share_mem_block(mbb_1[m1_idx]):
smem_per_interleave_cnt = smem_per_interleave_cnt + 1
m1_idx += 1
if smem_per_interleave_cnt >= smem_per_interleave:
Expand Down
5 changes: 5 additions & 0 deletions python/codegen_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ def emit_global_macro(self):
macro_c_clear_t(self.mc).emit()
if self.mc.arch_config.use_dlops:
self._emit_fma_macro()
if hasattr(self.kernel_list[0], 'use_bf16_1k_in_fp16'):
if self.kernel_list[0].use_bf16_1k_in_fp16():
sym = self.kernel_list[0].get_predefine_for_bf16_1k_in_fp16()
dfv = self.kernel_list[0].get_predefine_for_bf16_1k_in_fp16_default_value()
inst_mfma_emit_macro_mfma_16f(self.mc, sym, dfv)

def emit_global_macro_per_s_file(self, mc):
# emit global macro, independent of tunable
Expand Down
Loading

0 comments on commit 92dd200

Please sign in to comment.