Add int4/int8 vicuna
jinchen62 committed Jun 26, 2023
1 parent 75672c0 commit 2523500
Showing 4 changed files with 336 additions and 66 deletions.
47 changes: 45 additions & 2 deletions apps/language_models/src/model_wrappers/vicuna_model.py
@@ -1,14 +1,37 @@
import torch
from transformers import AutoModelForCausalLM

from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


class FirstVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(self, model_path, precision):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_type="float32",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float32",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048)
print("Weight quantization applied.")

def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
@@ -22,12 +45,32 @@ def forward(self, input_ids):


class SecondVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(self, model_path, precision):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_type="float32",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float32",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048)
print("Weight quantization applied.")

def forward(
self,
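
For reference, a minimal usage sketch of the new constructor signature. The import path follows this file's location, and the model path, precision choices, and input shape are illustrative placeholders; the pipelines below pass self.hf_model_path and self.precision.

import torch

from apps.language_models.src.model_wrappers.vicuna_model import (
    FirstVicuna,
    SecondVicuna,
)

# Hypothetical values for illustration only.
model_path = "path/to/vicuna-7b"
first = FirstVicuna(model_path, "int4")    # 4-bit grouped weight quantization
second = SecondVicuna(model_path, "int8")  # 8-bit grouped weight quantization

# FirstVicuna still takes plain token ids; [1, 19] mirrors the shape the
# pipeline compiles with.
input_ids = torch.zeros([1, 19], dtype=torch.int64)
out = first(input_ids)
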
84 changes: 64 additions & 20 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -17,7 +38,8 @@
import re
import torch
import torch_mlir
import os
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from brevitas_examples.llm.llm_quant.mlir_custom_mm import brevitas_matmul_rhs_group_quant_library


class Vicuna(SharkLLMBase):
@@ -38,9 +39,9 @@ def __init__(
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if precision in ["int4", "int8"]:
print("int4 and int8 are not supported yet, using fp32")
precision = "fp32"
# if precision in ["int4", "int8"]:
# print("int4 and int8 are not supported yet, using fp32")
# precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
@@ -133,13 +134,14 @@ def compile_first_vicuna(self):
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(self.hf_model_path)
model = FirstVicuna(self.hf_model_path, self.precision)

print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
@@ -151,13 +153,28 @@
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
if self.precision in ["int4", "int8"]:
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
else:
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph

def remove_constant_dim(line):
@@ -180,9 +197,15 @@ def remove_constant_dim(line):

module = str(module)
new_lines = []
test_lines = []

print(f"[DEBUG] rewriting torch_mlir file")
for line in module.splitlines():
if len(line) < 1000:
test_lines.append(line)
else:
test_lines.append(line[:999])

line = remove_constant_dim(line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append(
@@ -198,6 +221,11 @@ def remove_constant_dim(line):

module = "\n".join(new_lines)

test_module = "\n".join(test_lines)
f1_ = open(f"first_vicuna_test.mlir", "w+")
f1_.write(test_module)
f1_.close()

print(f"[DEBUG] converting to bytecode")
del new_lines
module = module.encode("UTF-8")
@@ -273,11 +301,12 @@ def compile_second_vicuna(self):
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(self.hf_model_path)
model = SecondVicuna(self.hf_model_path, self.precision)
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False] + [True] * 64,
mlir_type="torchscript",
)
@@ -296,13 +325,28 @@
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
if self.precision in ["int4", "int8"]:
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
else:
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)

def remove_constant_dim(line):
if "c19_i64" in line:
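
For the int4/int8 path, the branch above stops torch_mlir at the Torch dialect, keeping the custom brevitas.matmul_rhs_group_quant op backend-legal and supplying Brevitas' extra shape-function library, and then lowers to linalg-on-tensors with a separate pass pipeline. A condensed sketch of that two-step lowering, factored out of compile_first_vicuna / compile_second_vicuna; the function name and argument names are illustrative.

import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from brevitas_examples.llm.llm_quant.mlir_custom_mm import (
    brevitas_matmul_rhs_group_quant_library,
)


def lower_quantized_graph(ts_graph, compile_inputs):
    # Step 1: import to the Torch dialect, leaving the Brevitas group-quant
    # matmul as a backend-legal op resolved via the extra library.
    module = torch_mlir.compile(
        ts_graph,
        [*compile_inputs],
        output_type=torch_mlir.OutputType.TORCH,
        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
        extra_library=brevitas_matmul_rhs_group_quant_library,
        use_tracing=False,
        verbose=False,
    )
    # Step 2: lower Torch backend IR to linalg-on-tensors (in place).
    run_pipeline_with_repro_report(
        module,
        "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
    )
    return module
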
130 changes: 105 additions & 25 deletions apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
@@ -28,6 +28,12 @@
import json


from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
from brevitas_examples.llm.llm_quant.mlir_custom_mm import brevitas_matmul_rhs_group_quant_library


class Vicuna(SharkLLMBase):
# Class representing Sharded Vicuna Model
def __init__(
@@ -136,6 +142,7 @@ def compile_vicuna_layer(
vicuna_layer,
model_inputs,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
@@ -326,17 +333,36 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
ts_g = self.compile_vicuna_layer(
layer, inputs[0], inputs[1], inputs[2]
)
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
if self.precision in ["int4", "int8"]:
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
else:
module = torch_mlir.compile(
ts_g,
(
hidden_states_placeholder,
inputs[1],
inputs[2],
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
else:
ts_g = self.compile_vicuna_layer(
layer,
@@ -346,29 +372,62 @@ def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True):
inputs[3],
inputs[4],
)
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
if self.precision in ["int4", "int8"]:
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
output_type="torch",
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module, "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR")
else:
module = torch_mlir.compile(
ts_g,
(
inputs[0],
attention_mask_placeholder,
inputs[2],
pkv0_placeholder,
pkv1_placeholder,
),
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)

if is_first:
module = self.write_in_dynamic_inputs0(str(module), 137)
if idx in [0, 5, 6, 7]:
module_str = module
module_str = module_str.splitlines()
new_lines = []
for line in module_str:
if len(line) < 1000:
new_lines.append(line)
else:
new_lines.append(line[:999])
module_str = "\n".join(new_lines)
f1_ = open(f"{idx}_0_test.mlir", "w+")
f1_.write(module_str)
f1_.close()
bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()

else:
module = self.write_in_dynamic_inputs1(str(module), 138)

bytecode = module.encode("UTF-8")
bytecode_stream = BytesIO(bytecode)
bytecode = bytecode_stream.read()
@@ -458,6 +517,27 @@ def get_sharded_model(self, device="cpu"):
# please don't change it
SAMPLE_INPUT_LEN = 137
vicuna_model = self.get_src_model()

if self.precision in ["int4", "int8"]:
print("Applying weight quantization..")
quantize_model(
get_model_impl(vicuna_model).layers,
weight_quant_type="asym",
weight_bit_width=8,
weight_param_method="stats",
weight_scale_type="float32",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float32",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048)
print("Weight quantization applied.")

placeholder_input0 = (
torch.zeros([1, SAMPLE_INPUT_LEN, 4096]),
torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]),
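
The same Brevitas group-quantization call appears three times in this commit (FirstVicuna, SecondVicuna, and get_sharded_model above, which passes weight_bit_width=8 for both precisions). A sketch of how it could be factored into a shared helper; the helper name and its defaults are illustrative and not part of the commit.

from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


def apply_group_quant(model, weight_bit_width, group_size=128, seqlen=2048):
    # Asymmetric per-group weight quantization with fp32 scales, matching the
    # arguments used throughout this commit; activations stay unquantized
    # (input_bit_width=None).
    quantize_model(
        get_model_impl(model).layers,
        weight_quant_type="asym",
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_type="float32",
        weight_quant_granularity="per_group",
        weight_group_size=group_size,
        quantize_weight_zero_point=False,
        input_bit_width=None,
        input_scale_type="float32",
        input_param_method="stats",
        input_quant_type="asym",
        input_quant_granularity="per_tensor",
        quantize_input_zero_point=False,
        seqlen=seqlen,
    )
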