Add int4/int8 vicuna
jinchen62 committed Jun 29, 2023
1 parent 534de05 commit 3fadf17
Showing 4 changed files with 368 additions and 63 deletions.
49 changes: 47 additions & 2 deletions apps/language_models/src/model_wrappers/vicuna_model.py
@@ -1,14 +1,38 @@
import torch
from transformers import AutoModelForCausalLM

from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


class FirstVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(self, model_path, precision):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("First Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_type="float32",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float32",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")

def forward(self, input_ids):
op = self.model(input_ids=input_ids, use_cache=True)
@@ -22,12 +46,33 @@ def forward(self, input_ids):


class SecondVicuna(torch.nn.Module):
def __init__(self, model_path):
def __init__(self, model_path, precision):
super().__init__()
kwargs = {"torch_dtype": torch.float32}
self.model = AutoModelForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if precision in ["int4", "int8"]:
print("Second Vicuna applying weight quantization..")
weight_bit_width = 4 if precision == "int4" else 8
quantize_model(
get_model_impl(self.model).layers,
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_type="float32",
weight_quant_granularity="per_group",
weight_group_size=128,
quantize_weight_zero_point=False,
input_bit_width=None,
input_scale_type="float32",
input_param_method="stats",
input_quant_type="asym",
input_quant_granularity="per_tensor",
quantize_input_zero_point=False,
seqlen=2048,
)
print("Weight quantization applied.")

def forward(
self,
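For reference, a minimal usage sketch of the updated wrappers with the new precision argument. The model id and the chosen precisions below are illustrative assumptions, not values from this commit, and the import path assumes the repository root is on PYTHONPATH.

# Hypothetical usage sketch of the updated wrappers.
from apps.language_models.src.model_wrappers.vicuna_model import (
    FirstVicuna,
    SecondVicuna,
)

model_path = "lmsys/vicuna-7b-v1.3"  # assumed Hugging Face model id, placeholder only

# "int8" / "int4" trigger brevitas group-wise weight quantization in __init__;
# any other value (e.g. "fp16") keeps the unquantized float32 weights.
first_vicuna = FirstVicuna(model_path, precision="int8")
second_vicuna = SecondVicuna(model_path, precision="int4")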
84 changes: 67 additions & 17 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -17,7 +17,10 @@
import re
import torch
import torch_mlir
import os
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from brevitas_examples.llm.llm_quant.mlir_custom_mm import (
brevitas_matmul_rhs_group_quant_library,
)


class Vicuna(SharkLLMBase):
@@ -135,13 +138,14 @@ def compile_first_vicuna(self):
compilation_input_ids
).reshape([1, 19])
firstVicunaCompileInput = (compilation_input_ids,)
model = FirstVicuna(self.hf_model_path)
model = FirstVicuna(self.hf_model_path, self.precision)

print(f"[DEBUG] generating torchscript graph")
ts_graph = import_with_fx(
model,
firstVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False, False],
mlir_type="torchscript",
)
@@ -153,13 +157,30 @@ def compile_first_vicuna(self):
firstVicunaCompileInput[0], dynamic_axes=[1]
)
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
if self.precision in ["int4", "int8"]:
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
module = torch_mlir.compile(
ts_graph,
[*firstVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
del ts_graph

def remove_constant_dim(line):
@@ -182,9 +203,15 @@ def remove_constant_dim(line):

module = str(module)
new_lines = []
test_lines = []

print(f"[DEBUG] rewriting torch_mlir file")
for line in module.splitlines():
if len(line) < 1000:
test_lines.append(line)
else:
test_lines.append(line[:999])

line = remove_constant_dim(line)
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
new_lines.append(
@@ -200,6 +227,11 @@ def remove_constant_dim(line):

module = "\n".join(new_lines)

test_module = "\n".join(test_lines)
f1_ = open(f"first_vicuna_test.mlir", "w+")
f1_.write(test_module)
f1_.close()

print(f"[DEBUG] converting to bytecode")
del new_lines
module = module.encode("UTF-8")
@@ -275,11 +307,12 @@ def compile_second_vicuna(self):
for _ in range(64)
)
secondVicunaCompileInput = (compilation_input_ids,) + pkv
model = SecondVicuna(self.hf_model_path)
model = SecondVicuna(self.hf_model_path, self.precision)
ts_graph = import_with_fx(
model,
secondVicunaCompileInput,
is_f16=self.precision == "fp16",
precision=self.precision,
f16_input_mask=[False] + [True] * 64,
mlir_type="torchscript",
)
@@ -298,13 +331,30 @@ def compile_second_vicuna(self):
secondVicunaCompileInput[i], dynamic_axes=[2]
)
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)
if self.precision in ["int4", "int8"]:
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
output_type=torch_mlir.OutputType.TORCH,
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
extra_library=brevitas_matmul_rhs_group_quant_library,
use_tracing=False,
verbose=False,
)
print(f"[DEBUG] converting torch to linalg")
run_pipeline_with_repro_report(
module,
"builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
)
else:
module = torch_mlir.compile(
ts_graph,
[*secondVicunaCompileInput],
torch_mlir.OutputType.LINALG_ON_TENSORS,
use_tracing=False,
verbose=False,
)

def remove_constant_dim(line):
if "c19_i64" in line:
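Both compile_first_vicuna and compile_second_vicuna use the same two-stage lowering for int4/int8. A condensed sketch of that pattern, based only on the calls shown in the diff above (the helper function itself is not part of this commit):

import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from brevitas_examples.llm.llm_quant.mlir_custom_mm import (
    brevitas_matmul_rhs_group_quant_library,
)

def lower_quantized_ts_graph(ts_graph, compile_inputs):
    # Import to the Torch dialect while keeping the custom brevitas
    # group-quantized matmul op legal, then lower to linalg-on-tensors
    # in a separate pipeline step.
    module = torch_mlir.compile(
        ts_graph,
        [*compile_inputs],
        output_type=torch_mlir.OutputType.TORCH,
        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
        extra_library=brevitas_matmul_rhs_group_quant_library,
        use_tracing=False,
        verbose=False,
    )
    run_pipeline_with_repro_report(
        module,
        "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
    )
    return module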