From 7c1ce66c6b89be9e1e91f037d216f86aefdf8f92 Mon Sep 17 00:00:00 2001 From: PhaneeshB Date: Fri, 11 Aug 2023 01:40:32 +0530 Subject: [PATCH] fix combine mlir for llama2 --- apps/language_models/scripts/vicuna.py | 74 ++++++++++++++------------ 1 file changed, 40 insertions(+), 34 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 00ad45d4c7..e4d0f1b79d 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -244,7 +244,8 @@ def combine_mlir_scripts( print(f"[DEBUG] output_name = {output_name}") maps1 = [] maps2 = [] - constants = set() + constants_1 = set() + constants_2 = set() f1 = [] f2 = [] @@ -255,7 +256,7 @@ def combine_mlir_scripts( if re.search("#map\d*\s*=", line): maps1.append(line) elif re.search("arith.constant", line): - constants.add(line) + constants_1.add(line) elif not re.search("module", line): line = re.sub("forward", "first_vicuna_forward", line) f1.append(line) @@ -281,7 +282,7 @@ def combine_mlir_scripts( elif "global_seed" in line: continue elif re.search("arith.constant", line): - constants.add(line) + constants_2.add(line) elif not re.search("module", line): line = re.sub("forward", "second_vicuna_forward", line) f2.append(line) @@ -304,15 +305,21 @@ def combine_mlir_scripts( module_end = "}" global_vars = [] - vnames = [] - global_var_loading1 = [] - global_var_loading2 = [] + global_var_loading1 = dict() + global_var_loading2 = dict() print(f"[DEBUG] processing constants") - counter = 0 - constants = list(constants) + # in both 1 and 2 + constants = [(e , "") for e in list(constants_1 & constants_2)] + # only in 1 + constants.extend([(e, "_1") for e in list(constants_1.difference(constants_2))]) + # only in 2 + constants.extend([(e, "_2") for e in list(constants_2.difference(constants_1))]) + del constants_1, constants_2 + gc.collect() + while constants: - constant = constants.pop(0) + constant, vname_suf = constants.pop(0) vname, vbody = constant.split("=") vname = re.sub("%", "", vname) vname = vname.strip() @@ -322,35 +329,34 @@ def combine_mlir_scripts( print(constant) vdtype = vbody.split(":")[-1].strip() fixed_vdtype = vdtype - noinline = "{noinline}" if "tensor" in fixed_vdtype else "" - if "c1_i64" in vname: - print(constant) - counter += 1 - if counter == 2: - counter = 0 - print("detected duplicate") - continue - vnames.append(vname) if "true" not in vname: global_vars.append( - f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}" - ) - global_var_loading1.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" - ) - global_var_loading2.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" + f"ml_program.global private @{vname}{vname_suf}({vbody}) : {fixed_vdtype}" ) + if vname_suf != "_2": + global_var_loading1[ + f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}" + ] = "" + if vname_suf != "_1": + global_var_loading2[ + f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}" + ] = "" else: global_vars.append( - f"ml_program.global private @{vname}({vbody}) : i1" - ) - global_var_loading1.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" - ) - global_var_loading2.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" + f"ml_program.global private @{vname}{vname_suf}({vbody}) : i1" ) + if vname_suf != "_2": + global_var_loading1[ + f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1" + ] = "" + if vname_suf != "_1": + global_var_loading2[ + f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1" + ] = "" + + del constants + gc.collect() + new_f1, new_f2 = [], [] @@ -358,7 +364,7 @@ def combine_mlir_scripts( for line in f1: if "func.func" in line: new_f1.append(line) - for global_var in global_var_loading1: + for global_var in global_var_loading1.keys(): new_f1.append(global_var) else: new_f1.append(line) @@ -367,7 +373,7 @@ def combine_mlir_scripts( for line in f2: if "func.func" in line: new_f2.append(line) - for global_var in global_var_loading2: + for global_var in global_var_loading2.keys(): if ( "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in global_var