fix combine mlir for llama2
PhaneeshB committed Nov 20, 2023
1 parent 80a33d4 commit a53e714
Showing 1 changed file with 40 additions and 34 deletions.
74 changes: 40 additions & 34 deletions apps/language_models/scripts/vicuna.py
@@ -244,7 +244,8 @@ def combine_mlir_scripts(
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
constants_1 = set()
constants_2 = set()
f1 = []
f2 = []

@@ -255,7 +256,7 @@ def combine_mlir_scripts(
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
constants_1.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
@@ -281,7 +282,7 @@ def combine_mlir_scripts(
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
constants_2.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
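The two hunks above change the scan of each module so that its `arith.constant` lines land in a per-module set (`constants_1`, `constants_2`) while the shared `forward` entry point is renamed. A minimal standalone sketch of that scan is below; the MLIR snippet and the variable contents are invented for illustration and are not part of the commit.

```python
import re

# Invented stand-in for the first vicuna module's MLIR text.
first_module = """\
module {
  %c1_i64 = arith.constant 1 : i64
  %c1_i64_0 = arith.constant 1 : i64
  func.func @forward(%arg0: i64) -> i64 {
    return %arg0 : i64
  }
}"""

constants_1 = set()  # per-module set, so later set algebra can split shared vs. unique
f1 = []              # remaining body lines, with the entry point renamed

for line in first_module.splitlines():
    if re.search(r"arith.constant", line):
        constants_1.add(line)            # sets drop textually identical constants
    elif not re.search("module", line):  # skip the module { ... } wrapper line
        f1.append(re.sub("forward", "first_vicuna_forward", line))

print(len(constants_1), "constant lines")
print("\n".join(f1))
```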
@@ -304,15 +305,21 @@ def combine_mlir_scripts(
module_end = "}"

global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
global_var_loading1 = dict()
global_var_loading2 = dict()

print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
# in both 1 and 2
constants = [(e , "") for e in list(constants_1 & constants_2)]
# only in 1
constants.extend([(e, "_1") for e in list(constants_1.difference(constants_2))])
# only in 2
constants.extend([(e, "_2") for e in list(constants_2.difference(constants_1))])
del constants_1, constants_2
gc.collect()

while constants:
constant = constants.pop(0)
constant, vname_suf = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
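The rewritten setup above replaces the single merged constants set with an explicit three-way split, tagging each constant with the module it came from so it can be suffixed later. A small sketch of that partition, with made-up constant lines standing in for the `constants_1`/`constants_2` sets built in the earlier hunks:

```python
# Invented example contents; the real sets hold arith.constant lines parsed above.
constants_1 = {"%c0_i64 = arith.constant 0 : i64", "%cst = arith.constant 1.0 : f32"}
constants_2 = {"%c0_i64 = arith.constant 0 : i64", "%c20_i64 = arith.constant 20 : i64"}

# Constants present in both modules get no suffix; module-specific ones are
# tagged so their ml_program globals cannot collide after the merge.
constants = [(c, "") for c in constants_1 & constants_2]
constants += [(c, "_1") for c in constants_1 - constants_2]
constants += [(c, "_2") for c in constants_2 - constants_1]

for text, suffix in constants:
    print(suffix or "(shared)", "->", text)
```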
@@ -322,43 +329,42 @@ def combine_mlir_scripts(
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
f"ml_program.global private @{vname}{vname_suf}({vbody}) : {fixed_vdtype}"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
f"ml_program.global private @{vname}{vname_suf}({vbody}) : i1"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
] = ""

del constants
gc.collect()


new_f1, new_f2 = [], []

print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
for global_var in global_var_loading1.keys():
new_f1.append(global_var)
else:
new_f1.append(line)
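The hunk above carries the core of the fix: each tagged constant becomes one `ml_program.global` whose symbol carries the suffix, while the per-function load lines keep the original SSA name and are stored as dict keys so duplicates collapse to a single entry. A self-contained sketch of that transformation for a single constant follows; the constant text is invented.

```python
# One invented tagged constant, as produced by the shared/unique partition above.
constant, vname_suf = "%c20_i64 = arith.constant 20 : i64", "_2"

vname, vbody = constant.split("=")
vname = vname.replace("%", "").strip()
vbody = vbody.strip()
vdtype = vbody.split(":")[-1].strip()

global_vars = []
global_var_loading1, global_var_loading2 = dict(), dict()  # dicts double as ordered sets

global_vars.append(
    f"ml_program.global private @{vname}{vname_suf}({vbody}) : {vdtype}"
)
load = f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {vdtype}"
if vname_suf != "_2":  # shared or first-module-only constants feed first_vicuna_forward
    global_var_loading1[load] = ""
if vname_suf != "_1":  # shared or second-module-only constants feed second_vicuna_forward
    global_var_loading2[load] = ""

print(global_vars[0])
print(list(global_var_loading2))
```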
@@ -367,7 +373,7 @@ def combine_mlir_scripts(
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
for global_var in global_var_loading2.keys():
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
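The last visible hunk switches the splice loop to iterate the dict's keys, so each `global_load_const` line is inserted once right after the function signature. A short sketch of that splice, using an invented function body and load line:

```python
# Invented second-forward body and one invented load line.
f2 = [
    "  func.func @second_vicuna_forward(%arg0: i64) -> i64 {",
    "    return %arg0 : i64",
    "  }",
]
global_var_loading2 = {
    "\t\t%c20_i64 = ml_program.global_load_const @c20_i64_2 : i64": "",
}

new_f2 = []
for line in f2:
    new_f2.append(line)
    if "func.func" in line:
        # dict keys preserve insertion order and are unique, so each load
        # statement is emitted exactly once, right after the signature
        new_f2.extend(global_var_loading2.keys())

print("\n".join(new_f2))
```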
