Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama2 - FP16 - IR generation #1946

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 40 additions & 34 deletions apps/language_models/scripts/vicuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ def combine_mlir_scripts(
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
constants_1 = set()
constants_2 = set()
f1 = []
f2 = []

Expand All @@ -255,7 +256,7 @@ def combine_mlir_scripts(
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
constants_1.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
Expand All @@ -281,7 +282,7 @@ def combine_mlir_scripts(
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
constants_2.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
Expand All @@ -304,15 +305,21 @@ def combine_mlir_scripts(
module_end = "}"

global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
global_var_loading1 = dict()
global_var_loading2 = dict()

print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
# in both 1 and 2
constants = [(e , "") for e in list(constants_1 & constants_2)]
# only in 1
constants.extend([(e, "_1") for e in list(constants_1.difference(constants_2))])
# only in 2
constants.extend([(e, "_2") for e in list(constants_2.difference(constants_1))])
del constants_1, constants_2
gc.collect()

while constants:
constant = constants.pop(0)
constant, vname_suf = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
Expand All @@ -322,43 +329,42 @@ def combine_mlir_scripts(
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
f"ml_program.global private @{vname}{vname_suf}({vbody}) : {fixed_vdtype}"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
f"ml_program.global private @{vname}{vname_suf}({vbody}) : i1"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
] = ""

del constants
gc.collect()


new_f1, new_f2 = [], []

print(f"[DEBUG] processing f1")
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
for global_var in global_var_loading1.keys():
new_f1.append(global_var)
else:
new_f1.append(line)
Expand All @@ -367,7 +373,7 @@ def combine_mlir_scripts(
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
for global_var in global_var_loading2.keys():
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
Expand Down
Loading