Llama65B patch for int4 fp32
Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma committed Aug 16, 2023
1 parent 6da391c commit bd5b2b4
Showing 3 changed files with 402 additions and 88 deletions.
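The central shape change in this patch: the decode-step ("second vicuna") compile inputs grow from 64 past-key-value tensors of shape [1, 32, 19, 128] to 160 tensors of shape [1, 64, 19, 128], and the fp16 input mask grows from [False] + [True] * 64 to [False] + [True] * 160. That follows from the published LLaMA configurations: LLaMA-65B has 80 decoder layers (80 x 2 = 160 key/value tensors) and 64 attention heads with a 128-wide head dimension, while the 7B models have 32 layers and 32 heads. The Python sketch below is illustrative only and not part of the patch; batch size 1 and sequence length 19 mirror the compilation prompt used in the hunks that follow.

# Illustrative sketch, not part of the patch: how the past-key-value
# placeholder shapes used in compile() follow from the model configs.
def pkv_placeholder_shapes(num_layers, num_heads, hidden_size, seq_len=19, batch=1):
    head_dim = hidden_size // num_heads      # 128 for both model sizes
    num_tensors = 2 * num_layers             # one key and one value tensor per layer
    return num_tensors, [batch, num_heads, seq_len, head_dim]

print(pkv_placeholder_shapes(80, 64, 8192))  # LLaMA-65B -> (160, [1, 64, 19, 128])
print(pkv_placeholder_shapes(32, 32, 4096))  # 7B models -> (64, [1, 32, 19, 128])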
apps/language_models/scripts/vicuna.py: 185 changes (100 additions, 85 deletions)
@@ -108,7 +108,7 @@
"--model_name",
type=str,
default="vicuna",
choices=["vicuna", "llama2_7b", "llama2_70b"],
choices=["vicuna", "llama_65b", "llama2_7b", "llama2_70b"],
help="Specify which model to run.",
)
parser.add_argument(
@@ -161,7 +161,7 @@ class VicunaBase(SharkLLMBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_model_path="elinas/llama-65b-hf-transformers-4.29",
max_num_tokens=512,
device="cpu",
precision="int8",
@@ -433,7 +433,7 @@ class ShardedVicuna(VicunaBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_model_path="elinas/llama-65b-hf-transformers-4.29",
max_num_tokens=512,
device="cuda",
precision="fp32",
@@ -1212,7 +1212,7 @@ class UnshardedVicuna(VicunaBase):
def __init__(
self,
model_name,
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
hf_model_path="elinas/llama-65b-hf-transformers-4.29",
hf_auth_token: str = None,
max_num_tokens=512,
device="cpu",
@@ -1232,7 +1232,9 @@ def __init__(
"HF auth token required. Pass it using --hf_auth_token flag."
)
self.hf_auth_token = hf_auth_token
if self.model_name == "llama2_7b":
if self.model_name == "llama_65b":
self.hf_model_path = "elinas/llama-65b-hf-transformers-4.29"
elif self.model_name == "llama2_7b":
self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf"
elif self.model_name == "llama2_70b":
self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf"
@@ -1423,21 +1425,21 @@ def compile(self, download_vmfb=False):
else:
compilation_prompt = "".join(["0" for _ in range(17)])

if Path(f"first_{self.precision}.mlir").exists():
print(f"loading first_{self.precision}.mlir")
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
first_module = f.read()
if Path(f"second_{self.precision}.mlir").exists():
print(f"loading second_{self.precision}.mlir")
with open(Path(f"second_{self.precision}.mlir"), "r") as f:
second_module = f.read()
else:
# generate first vicuna
compilation_input_ids = self.tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
# generate second vicuna
compilation_input_ids = torch.zeros(
[1, 1], dtype=torch.int64
)
pkv = tuple(
(torch.zeros([1, 64, 19, 128], dtype=torch.float32))
for _ in range(160)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
@@ -1447,27 +1449,33 @@ def compile(self, download_vmfb=False):
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False, False],
f16_input_mask=[False] + [True] * 160,
mlir_type="torchscript",
)
del model
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[
0
] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)

firstVicunaCompileInput = tuple(firstVicunaCompileInput)
first_module = None
if self.precision == "fp16":
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 160,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
first_module = torch_mlir.compile(
second_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=[
"brevitas.matmul_rhs_group_quant"
@@ -1478,47 +1486,58 @@ def compile(self, download_vmfb=False):
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
first_module,
second_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
first_module = torch_mlir.compile(
second_module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
from contextlib import redirect_stdout
print("Writing : second_llama_65b_linalg_ir_before_dynamic ELIDED")
with open('second_llama_65b_linalg_ir_before_dynamic_elided.mlir', 'w') as f:
with redirect_stdout(f):
print(second_module.operation.get_asm(large_elements_limit=4))
print("FINISHED")
del ts_graph
del firstVicunaCompileInput
del secondVicunaCompileInput
gc.collect()

print(
"[DEBUG] successfully generated first vicuna linalg mlir"
"[DEBUG] successfully generated second vicuna linalg mlir"
)
first_module = self.write_in_dynamic_inputs0(
str(first_module), dynamic_input_size=19
second_module = self.write_in_dynamic_inputs1(
str(second_module)
)
if self.cache_vicunas:
with open(f"first_{self.precision}.mlir", "w+") as f:
f.write(first_module)
print("Writing : second_llama_65b_linalg_ir_after_dynamic ELIDED")
with open('second_llama_65b_linalg_ir_after_dynamic_elided.mlir', 'w') as f:
with redirect_stdout(f):
print(second_module.operation.get_asm(large_elements_limit=4))
print("FINISHED")
# if self.cache_vicunas:
print("Writing : second_llama_65b_linalg_ir_after_dynamic")
with open(f"second_{self.precision}.mlir", "w+") as f:
f.write(second_module)

if Path(f"second_{self.precision}.mlir").exists():
print(f"loading second_{self.precision}.mlir")
with open(Path(f"second_{self.precision}.mlir"), "r") as f:
second_module = f.read()
if Path(f"first_{self.precision}.mlir").exists():
print(f"loading first_{self.precision}.mlir")
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
first_module = f.read()
else:
# generate second vicuna
compilation_input_ids = torch.zeros(
[1, 1], dtype=torch.int64
)
pkv = tuple(
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(
# generate first vicuna
compilation_input_ids = self.tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(
self.hf_model_path,
self.precision,
self.weight_group_size,
@@ -1528,33 +1547,27 @@ def compile(self, download_vmfb=False):
print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False] + [True] * 64,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
del model
if self.precision == "fp16":
secondVicunaCompileInput = get_f16_inputs(
secondVicunaCompileInput,
True,
f16_input_mask=[False] + [True] * 64,
)
secondVicunaCompileInput = list(secondVicunaCompileInput)
for i in range(len(secondVicunaCompileInput)):
if i != 0:
secondVicunaCompileInput[
i
] = torch_mlir.TensorPlaceholder.like(
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
firstVicunaCompileInput = list(firstVicunaCompileInput)
firstVicunaCompileInput[
0
] = torch_mlir.TensorPlaceholder.like(
firstVicunaCompileInput[0], dynamic_axes=[1]
)

firstVicunaCompileInput = tuple(firstVicunaCompileInput)
first_module = None
print(f"[DEBUG] generating torch mlir")
if self.precision in ["int4", "int8"]:
second_module = torch_mlir.compile(
first_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
[*firstVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=[
"brevitas.matmul_rhs_group_quant"
@@ -1565,30 +1578,31 @@ def compile(self, download_vmfb=False):
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
second_module,
first_module,
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
second_module = torch_mlir.compile(
first_module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph
del secondVicunaCompileInput
del firstVicunaCompileInput
gc.collect()

print(
"[DEBUG] successfully generated second vicuna linalg mlir"
"[DEBUG] successfully generated first vicuna linalg mlir"
)
second_module = self.write_in_dynamic_inputs1(
str(second_module)
first_module = self.write_in_dynamic_inputs0(
str(first_module), dynamic_input_size=19
)
if self.cache_vicunas:
with open(f"second_{self.precision}.mlir", "w+") as f:
f.write(second_module)
with open(f"first_{self.precision}.mlir", "w+") as f:
f.write(first_module)

combined_module = self.combine_mlir_scripts(
first_module, second_module, self.vicuna_mlir_path
@@ -1752,6 +1766,7 @@ def autocomplete(self, prompt):

model_list = {
"vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF",
"llama_65b": "elinas/llama-65b-hf-transformers-4.29",
"llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf",
"llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf",
}
(Diffs for the remaining two changed files are not shown here.)
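For reference, a minimal usage sketch (not part of the patch) of the new "llama_65b" choice. The import path is an assumption based on the changed file's location, and the precision keyword is likewise an assumption, since the UnshardedVicuna constructor is only partially visible in the hunks above.

# Hypothetical usage, assuming the import path below and a precision kwarg
# on UnshardedVicuna (only VicunaBase's precision default is visible above).
from apps.language_models.scripts.vicuna import UnshardedVicuna

llama65b = UnshardedVicuna(
    model_name="llama_65b",  # new choice added by this patch
    device="cpu",            # hf_model_path defaults to elinas/llama-65b-hf-transformers-4.29
    precision="int4",        # assumed keyword argument
    # hf_auth_token="...",   # pass a token if the checkpoint requires one
)
llama65b.compile()           # builds or loads first_int4.mlir / second_int4.mlir and combines them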
