From 29f5ca1ba365d19050b44c0beb969dd16893c767 Mon Sep 17 00:00:00 2001 From: jinchen62 Date: Thu, 15 Jun 2023 07:53:29 -0700 Subject: [PATCH] Add int4/int8 vicuna --- .gitignore | 3 + apps/language_models/scripts/vicuna.py | 1448 ++++++++++++++++- .../src/model_wrappers/vicuna_model.py | 51 +- .../src/pipelines/vicuna_pipeline.py | 612 ------- .../src/pipelines/vicuna_sharded_pipeline.py | 686 -------- requirements.txt | 3 + setup_venv.sh | 2 + shark/shark_importer.py | 154 +- 8 files changed, 1619 insertions(+), 1340 deletions(-) delete mode 100644 apps/language_models/src/pipelines/vicuna_pipeline.py delete mode 100644 apps/language_models/src/pipelines/vicuna_sharded_pipeline.py diff --git a/.gitignore b/.gitignore index eeb217e2b6..d0c67bd5e3 100644 --- a/.gitignore +++ b/.gitignore @@ -189,3 +189,6 @@ apps/stable_diffusion/web/models/ # Stencil annotators. stencil_annotator/ + +# brevitas custom op lib +apps/language_models/scripts/vicuna.py diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 54c32c8269..8d5cab1855 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1,9 +1,44 @@ import argparse +import json +import re +from io import BytesIO from pathlib import Path -from apps.language_models.src.pipelines import vicuna_pipeline as vp -from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp +from tqdm import tqdm +from typing import List, Tuple + import torch -import json +import torch_mlir +from torch_mlir import TensorPlaceholder +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from transformers import AutoTokenizer, AutoModelForCausalLM + +from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase +from apps.language_models.src.model_wrappers.vicuna_sharded_model import ( + FirstVicunaLayer, + SecondVicunaLayer, + CompiledVicunaLayer, + ShardedVicunaModel, + LMHead, + LMHeadCompiled, + VicunaEmbedding, + VicunaEmbeddingCompiled, + VicunaNorm, + VicunaNormCompiled, +) +from apps.language_models.src.model_wrappers.vicuna_model import ( + FirstVicuna, + SecondVicuna, +) +from apps.language_models.utils import ( + get_vmfb_from_path, +) +from shark.shark_downloader import download_public_file +from shark.shark_importer import get_f16_inputs +from shark.shark_importer import import_with_fx +from shark.shark_inference import SharkInference + +from brevitas_examples.llm.llm_quant.quantize import quantize_model +from brevitas_examples.llm.llm_quant.run_utils import get_model_impl if __name__ == "__main__": import gc @@ -13,7 +48,6 @@ prog="vicuna runner", description="runs a vicuna model", ) - parser.add_argument( "--precision", "-p", default="fp32", help="fp32, fp16, int8, int4" ) @@ -29,7 +63,6 @@ help="Run model as sharded", ) # TODO: sharded config - parser.add_argument( "--second_vicuna_vmfb_path", default=None, @@ -45,7 +78,6 @@ default=None, help="path to second vicuna mlir", ) - parser.add_argument( "--load_mlir_from_shark_tank", default=False, @@ -58,12 +90,1408 @@ action=argparse.BooleanOptionalAction, help="Run model in cli mode", ) - parser.add_argument( "--config", default=None, help="configuration file", ) +parser.add_argument( + "--weight-group-size", + type=int, + default=128, + help="Group size for per_group weight quantization. 
Default: 128.", +) + + +def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: + if len(lhs) == 3 and len(rhs) == 2: + return [lhs[0], lhs[1], rhs[0]] + elif len(lhs) == 2 and len(rhs) == 2: + return [lhs[0], rhs[0]] + else: + raise ValueError("Input shapes not supported.") + + +def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: + # output dtype is the dtype of the lhs float input + lhs_rank, lhs_dtype = lhs_rank_dtype + return lhs_dtype + + +def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: + return + + +brevitas_matmul_rhs_group_quant_library = [ + brevitas〇matmul_rhs_group_quant〡shape, + brevitas〇matmul_rhs_group_quant〡dtype, + brevitas〇matmul_rhs_group_quant〡has_value_semantics] + + +class ShardedVicuna(SharkLLMBase): + # Class representing Sharded Vicuna Model + def __init__( + self, + model_name, + hf_model_path="TheBloke/vicuna-7B-1.1-HF", + max_num_tokens=512, + device="cuda", + precision="fp32", + config_json=None, + weight_group_size=128, + ) -> None: + super().__init__(model_name, hf_model_path, max_num_tokens) + self.max_sequence_length = 256 + self.device = device + self.precision = precision + self.tokenizer = self.get_tokenizer() + self.config = config_json + self.weight_group_size = weight_group_size + self.shark_model = self.compile(device=device) + + def get_tokenizer(self): + # Retrieve the tokenizer from Huggingface + tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_path, use_fast=False + ) + return tokenizer + + def get_src_model(self): + # Retrieve the torch model from Huggingface + kwargs = {"torch_dtype": torch.float} + vicuna_model = AutoModelForCausalLM.from_pretrained( + self.hf_model_path, **kwargs + ) + return vicuna_model + + def write_in_dynamic_inputs0(self, module, dynamic_input_size): + # Current solution for ensuring mlir files support dynamic inputs + # TODO find a more elegant way to implement this + new_lines = [] + for line in module.splitlines(): + line = re.sub(f"{dynamic_input_size}x", "?x", line) + if "?x" in line: + line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) + line = re.sub(f" {dynamic_input_size},", " %dim,", line) + if "tensor.empty" in line and "?x?" in line: + line = re.sub( + "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line + ) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + + def write_in_dynamic_inputs1(self, module, dynamic_input_size): + new_lines = [] + for line in module.splitlines(): + if "dim_42 =" in line: + continue + if f"%c{dynamic_input_size}_i64 =" in line: + new_lines.append( + "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>" + ) + new_lines.append( + f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64" + ) + continue + line = re.sub(f"{dynamic_input_size}x", "?x", line) + line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line) + if "?x" in line: + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dim_42)", line + ) + line = re.sub(f" {dynamic_input_size},", " %dim_42,", line) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub( + "tensor.empty\(%dim_42\)", + "tensor.empty(%dim_42, %dim_42)", + line, + ) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim_42", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + + def combine_mlir_scripts( + self, first_vicuna_mlir, second_vicuna_mlir, output_name + ): + maps1 = [] + maps2 = [] + constants = set() + f1 = [] + f2 = [] + + for line in first_vicuna_mlir.splitlines(): + if re.search("#map\d*\s*=", line): + maps1.append(line) + elif re.search("arith.constant", line): + constants.add(line) + elif not re.search("module", line): + line = re.sub("forward", "first_vicuna_forward", line) + f1.append(line) + f1 = f1[:-1] + + for i, map_line in enumerate(maps1): + map_var = map_line.split(" ")[0] + map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line) + maps1[i] = map_line + f1 = [ + re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line) + for func_line in f1 + ] + + for line in second_vicuna_mlir.splitlines(): + if re.search("#map\d*\s*=", line): + maps2.append(line) + elif "global_seed" in line: + continue + elif re.search("arith.constant", line): + constants.add(line) + elif not re.search("module", line): + line = re.sub("forward", "second_vicuna_forward", line) + f2.append(line) + f2 = f2[:-1] + + for i, map_line in enumerate(maps2): + map_var = map_line.split(" ")[0] + map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line) + maps2[i] = map_line + f2 = [ + re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line) + for func_line in f2 + ] + + module_start = ( + 'module attributes {torch.debug_module_name = "_lambda"} {' + ) + module_end = "}" + + global_vars = [] + vnames = [] + vdtypes = [] + global_var_loading1 = [] + global_var_loading2 = [] + + for constant in list(constants): + vname, vbody = constant.split("=") + vname = re.sub("%", "", vname) + vname = vname.strip() + vbody = re.sub("arith.constant", "", vbody) + vbody = vbody.strip() + vdtype = vbody.split(":")[1].strip() + fixed_vdtype = vdtype + vdtypes.append(vdtype) + vdtype = re.sub("\d{1,}x", "?x", vdtype) + vnames.append(vname) + global_vars.append( + f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}" + ) + global_var_loading1.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" + ) + global_var_loading2.append( + f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" + ) + + new_f1, new_f2 = [], [] + + for line in f1: + if "func.func" in line: + new_f1.append(line) + for global_var in global_var_loading1: + new_f1.append(global_var) + else: + new_f1.append(line) + + for line in f2: + if "func.func" in line: + new_f2.append(line) + for global_var in global_var_loading1: + new_f2.append(global_var) + else: + new_f2.append(line) + + f1 = new_f1 + f2 = new_f2 + + whole_string = "\n".join( + maps1 + + maps2 + + [module_start] + + global_vars + + f1 + + f2 + + [module_end] + ) + + f_ = open(output_name, "w+") + f_.write(whole_string) + f_.close() + + return whole_string + + def compile_vicuna_layer( + self, + vicuna_layer, + hidden_states, + attention_mask, + position_ids, + past_key_value0=None, + past_key_value1=None, + ): + # Compile a hidden decoder layer of vicuna + if past_key_value0 is None and past_key_value1 is None: + model_inputs = (hidden_states, attention_mask, position_ids) + else: + model_inputs = ( + hidden_states, + attention_mask, + position_ids, + past_key_value0, + past_key_value1, + ) + mlir_bytecode = import_with_fx( + vicuna_layer, + 
model_inputs, + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + return mlir_bytecode + + def get_device_index(self, layer_string): + # Get the device index from the config file + # In the event that different device indices are assigned to + # different parts of a layer, a majority vote will be taken and + # everything will be run on the most commonly used device + if self.config is None: + return None + idx_votes = {} + for key in self.config.keys(): + if re.search(layer_string, key): + if int(self.config[key]["gpu"]) in idx_votes.keys(): + idx_votes[int(self.config[key]["gpu"])] += 1 + else: + idx_votes[int(self.config[key]["gpu"])] = 1 + device_idx = max(idx_votes, key=idx_votes.get) + return device_idx + + def compile_lmhead( + self, lmh, hidden_states, device="cpu", device_idx=None + ): + # compile the lm head of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"lmhead.mlir") + vmfb_path = Path(f"lmhead.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + lmh, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + mmap=False, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="lmhead") + shark_module.load_module(vmfb_path) + compiled_module = LMHeadCompiled(shark_module) + return compiled_module + + def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): + # compile the normalization layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"norm.mlir") + vmfb_path = Path(f"norm.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + fvn, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + mmap=False, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="norm") + shark_module.load_module(vmfb_path) + compiled_module = VicunaNormCompiled(shark_module) + return compiled_module + + def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): + # compile the embedding layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"embedding.mlir") + vmfb_path = Path(f"embedding.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + input_ids = torch_mlir.TensorPlaceholder.like( + input_ids, 
dynamic_axes=[1] + ) + module = torch_mlir.compile( + fve, + (input_ids,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + mmap=False, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="embedding") + shark_module.load_module(vmfb_path) + compiled_module = VicunaEmbeddingCompiled(shark_module) + + return compiled_module + + def compile_to_vmfb_one_model( + self, inputs0, layers0, inputs1, layers1, device="cpu" + ): + mlirs, modules = [], [] + assert len(layers0) == len(layers1) + for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))): + mlir_path = Path(f"{idx}_full.mlir") + vmfb_path = Path(f"{idx}_full.vmfb") + # if vmfb_path.exists(): + # continue + if mlir_path.exists(): + # print(f"Found layer {idx} mlir") + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + mlirs.append(bytecode) + else: + hidden_states_placeholder0 = TensorPlaceholder.like( + inputs0[0], dynamic_axes=[1] + ) + attention_mask_placeholder0 = TensorPlaceholder.like( + inputs0[1], dynamic_axes=[3] + ) + position_ids_placeholder0 = TensorPlaceholder.like( + inputs0[2], dynamic_axes=[1] + ) + hidden_states_placeholder1 = TensorPlaceholder.like( + inputs1[0], dynamic_axes=[1] + ) + attention_mask_placeholder1 = TensorPlaceholder.like( + inputs1[1], dynamic_axes=[3] + ) + position_ids_placeholder1 = TensorPlaceholder.like( + inputs1[2], dynamic_axes=[1] + ) + pkv0_placeholder = TensorPlaceholder.like( + inputs1[3], dynamic_axes=[2] + ) + pkv1_placeholder = TensorPlaceholder.like( + inputs1[4], dynamic_axes=[2] + ) + + print(f"Compiling layer {idx} mlir") + ts_g = self.compile_vicuna_layer( + layer0, inputs0[0], inputs0[1], inputs0[2] + ) + if self.precision in ["int4", "int8"]: + module0 = torch_mlir.compile( + ts_g, + ( + hidden_states_placeholder0, + inputs0[1], + inputs0[2], + ), + output_type="torch", + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module0, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module0 = torch_mlir.compile( + ts_g, + ( + hidden_states_placeholder0, + inputs0[1], + inputs0[2], + ), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + module0 = self.write_in_dynamic_inputs0(str(module0), 137) + + ts_g = self.compile_vicuna_layer( + layer1, + inputs1[0], + inputs1[1], + inputs1[2], + inputs1[3], + inputs1[4], + ) + if self.precision in ["int4", "int8"]: + module1 = torch_mlir.compile( + ts_g, + ( + inputs1[0], + attention_mask_placeholder1, + inputs1[2], + pkv0_placeholder, + pkv1_placeholder, + ), + output_type="torch", + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + 
run_pipeline_with_repro_report( + module1, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module1 = torch_mlir.compile( + ts_g, + ( + inputs1[0], + attention_mask_placeholder1, + inputs1[2], + pkv0_placeholder, + pkv1_placeholder, + ), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + module1 = self.write_in_dynamic_inputs1(str(module1), 138) + + module_combined = self.combine_mlir_scripts( + module0, module1, f"{idx}_full.mlir" + ) + mlirs.append(module_combined) + + if vmfb_path.exists(): + # print(f"Found layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + None, + device=device, + device_idx=idx % 4, + mlir_dialect="tm_tensor", + mmap=False, + ) + module.load_module(vmfb_path) + else: + print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + mlirs[idx], + device=device, + device_idx=idx % 4, + mlir_dialect="tm_tensor", + mmap=False, + ) + module.save_module( + module_name=f"{idx}_full", + extra_args=[ + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + module.load_module(vmfb_path) + modules.append(module) + return mlirs, modules + + def get_sharded_model(self, device="cpu"): + # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess + # please don't change it + SAMPLE_INPUT_LEN = 137 + vicuna_model = self.get_src_model() + + if self.precision in ["int4", "int8"]: + print("Applying weight quantization..") + weight_bit_width = 4 if self.precision == "int4" else 8 + quantize_model( + get_model_impl(vicuna_model).layers, + dtype=torch.float32, + weight_quant_type="asym", + weight_bit_width=weight_bit_width, + weight_param_method="stats", + weight_scale_type="float", + weight_quant_granularity="per_group", + weight_group_size=self.weight_group_size, + quantize_weight_zero_point=False, + input_bit_width=None, + input_scale_type="float", + input_param_method="stats", + input_quant_type="asym", + input_quant_granularity="per_tensor", + quantize_input_zero_point=False, + seqlen=2048, + ) + print("Weight quantization applied.") + + placeholder_input0 = ( + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]), + torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64), + ) + + placeholder_input1 = ( + torch.zeros([1, 1, 4096]), + torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]), + torch.zeros([1, 1], dtype=torch.int64), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + ) + + norm = VicunaNorm(vicuna_model.model.norm) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.norm(?:\.|\s|$)" + ) + print(device_idx) + norm = self.compile_norm( + norm, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + + embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)" + ) + print(device_idx) + embeddings = self.compile_embedding( + embeddings, + (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)), + device=self.device, + device_idx=device_idx, + ) + + lmhead 
= LMHead(vicuna_model.lm_head) + device_idx = self.get_device_index( + r"vicuna\.model\.lm_head(?:\.|\s|$)" + ) + print(device_idx) + lmhead = self.compile_lmhead( + lmhead, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + + layers0 = [ + FirstVicunaLayer(layer) for layer in vicuna_model.model.layers + ] + layers1 = [ + SecondVicunaLayer(layer) for layer in vicuna_model.model.layers + ] + _, modules = self.compile_to_vmfb_one_model( + placeholder_input0, + layers0, + placeholder_input1, + layers1, + device=device, + ) + shark_layers = [CompiledVicunaLayer(m) for m in modules] + + sharded_model = ShardedVicunaModel( + vicuna_model, + shark_layers, + lmhead, + embeddings, + norm, + ) + return sharded_model + + def compile(self, device="cpu"): + return self.get_sharded_model(device=device) + + def generate(self, prompt, cli=False): + # TODO: refactor for cleaner integration + + tokens_generated = [] + _past_key_values = None + _token = None + detoks_generated = [] + for iteration in range(self.max_num_tokens): + params = { + "prompt": prompt, + "is_first": iteration == 0, + "token": _token, + "past_key_values": _past_key_values, + } + + generated_token_op = self.generate_new_token(params=params) + + _token = generated_token_op["token"] + _past_key_values = generated_token_op["past_key_values"] + _detok = generated_token_op["detok"] + + if _token == 2: + break + detoks_generated.append(_detok) + tokens_generated.append(_token) + + for i in range(len(tokens_generated)): + if type(tokens_generated[i]) != int: + tokens_generated[i] = int(tokens_generated[i][0]) + result_output = self.tokenizer.decode(tokens_generated) + return result_output + + def generate_new_token(self, params): + is_first = params["is_first"] + if is_first: + prompt = params["prompt"] + input_ids = self.tokenizer(prompt).input_ids + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + output = self.shark_model.forward(input_ids, is_first=is_first) + else: + token = params["token"] + past_key_values = params["past_key_values"] + input_ids = [token] + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + output = self.shark_model.forward( + input_ids, past_key_values=past_key_values, is_first=is_first + ) + + _logits = output["logits"] + _past_key_values = output["past_key_values"] + _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) + _detok = self.tokenizer.decode(_token) + + ret_dict = { + "token": _token, + "detok": _detok, + "past_key_values": _past_key_values, + } + + print(f" token : {_token} | detok : {_detok}") + + return ret_dict + + def autocomplete(self, prompt): + # use First vic alone to complete a story / prompt / sentence. 
+ pass + + +class UnshardedVicuna(SharkLLMBase): + def __init__( + self, + model_name, + hf_model_path="TheBloke/vicuna-7B-1.1-HF", + max_num_tokens=512, + device="cuda", + precision="fp32", + first_vicuna_mlir_path=None, + second_vicuna_mlir_path=None, + first_vicuna_vmfb_path=None, + second_vicuna_vmfb_path=None, + load_mlir_from_shark_tank=True, + low_device_memory=False, + weight_group_size=128, + ) -> None: + super().__init__(model_name, hf_model_path, max_num_tokens) + self.max_sequence_length = 256 + self.device = device + self.precision = precision + self.first_vicuna_vmfb_path = first_vicuna_vmfb_path + self.second_vicuna_vmfb_path = second_vicuna_vmfb_path + self.first_vicuna_mlir_path = first_vicuna_mlir_path + self.second_vicuna_mlir_path = second_vicuna_mlir_path + self.load_mlir_from_shark_tank = load_mlir_from_shark_tank + self.low_device_memory = low_device_memory + self.weight_group_size = weight_group_size + self.first_vic = None + self.second_vic = None + if self.first_vicuna_mlir_path == None: + self.first_vicuna_mlir_path = self.get_model_path() + if self.second_vicuna_mlir_path == None: + self.second_vicuna_mlir_path = self.get_model_path("second") + if self.first_vicuna_vmfb_path == None: + self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb") + if self.second_vicuna_vmfb_path == None: + self.second_vicuna_vmfb_path = self.get_model_path( + "second", "vmfb" + ) + self.tokenizer = self.get_tokenizer() + self.shark_model = self.compile() + + def get_model_path(self, model_number="first", suffix="mlir"): + safe_device = "_".join(self.device.split("-")) + if suffix == "mlir": + return Path(f"{model_number}_vicuna_{self.precision}.{suffix}") + return Path( + f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}" + ) + + def get_tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_path, use_fast=False + ) + return tokenizer + + def get_src_model(self): + kwargs = {"torch_dtype": torch.float} + vicuna_model = AutoModelForCausalLM.from_pretrained( + self.hf_model_path, **kwargs + ) + return vicuna_model + + def compile_first_vicuna(self): + vmfb = get_vmfb_from_path( + self.first_vicuna_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + return vmfb + + # Compilation path needs some more work before it is functional + print( + f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n" + f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}" + ) + if self.first_vicuna_mlir_path.exists(): + with open(self.first_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + else: + mlir_generated = False + if self.load_mlir_from_shark_tank: + if self.precision in ["fp32", "fp16", "int8", "int4"]: + # download MLIR from shark_tank + download_public_file( + f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}", + self.first_vicuna_mlir_path.absolute(), + single_file=True, + ) + if self.first_vicuna_mlir_path.exists(): + with open(self.first_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + mlir_generated = True + else: + raise ValueError( + f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}" + " after downloading! Please check path and try again" + ) + else: + print( + f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device." 
+ ) + + if not mlir_generated: + compilation_prompt = "".join(["0" for _ in range(17)]) + compilation_input_ids = self.tokenizer( + compilation_prompt + ).input_ids + compilation_input_ids = torch.tensor( + compilation_input_ids + ).reshape([1, 19]) + firstVicunaCompileInput = (compilation_input_ids,) + model = FirstVicuna( + self.hf_model_path, self.precision, self.weight_group_size + ) + + print(f"[DEBUG] generating torchscript graph") + ts_graph = import_with_fx( + model, + firstVicunaCompileInput, + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + del model + + firstVicunaCompileInput = list(firstVicunaCompileInput) + firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like( + firstVicunaCompileInput[0], dynamic_axes=[1] + ) + firstVicunaCompileInput = tuple(firstVicunaCompileInput) + + print(f"[DEBUG] generating torch mlir") + if self.precision in ["int4", "int8"]: + module = torch_mlir.compile( + ts_graph, + [*firstVicunaCompileInput], + output_type=torch_mlir.OutputType.TORCH, + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module = torch_mlir.compile( + ts_graph, + [*firstVicunaCompileInput], + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + del ts_graph + + def remove_constant_dim(line): + if "19x" in line: + line = re.sub("19x", "?x", line) + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dim)", line + ) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub( + "tensor.empty\(%dim\)", + "tensor.empty(%dim, %dim)", + line, + ) + if "arith.cmpi" in line: + line = re.sub("c19", "dim", line) + if " 19," in line: + line = re.sub(" 19,", " %dim,", line) + return line + + module = str(module) + new_lines = [] + + print(f"[DEBUG] rewriting torch_mlir file") + for line in module.splitlines(): + line = remove_constant_dim(line) + if "%0 = tensor.empty(%dim) : tensor" in line: + new_lines.append( + "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" + ) + if ( + "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" + in line + ): + continue + + new_lines.append(line) + + module = "\n".join(new_lines) + print(f"[DEBUG] converting to bytecode") + del new_lines + module = module.encode("UTF-8") + module = BytesIO(module) + bytecode = module.read() + del module + + print(f"[DEBUG] writing mlir to file") + f_ = open(self.first_vicuna_mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" + ) + path = shark_module.save_module( + self.first_vicuna_vmfb_path.parent.absolute(), + self.first_vicuna_vmfb_path.stem, + extra_args=[ + "--iree-hal-dump-executable-sources-to=ies", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + print("Saved first vic vmfb at ", str(path)) + shark_module.load_module(path) + + return shark_module + + def compile_second_vicuna(self): + vmfb = get_vmfb_from_path( + self.second_vicuna_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + return vmfb + + # Compilation path needs some more work before it is functional + print( + f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}" + ) + if self.second_vicuna_mlir_path.exists(): + with open(self.second_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + else: + mlir_generated = False + if self.load_mlir_from_shark_tank: + if self.precision in ["fp32", "fp16", "int8", "int4"]: + # download MLIR from shark_tank + download_public_file( + f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}", + self.second_vicuna_mlir_path.absolute(), + single_file=True, + ) + if self.second_vicuna_mlir_path.exists(): + with open(self.second_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + mlir_generated = True + else: + raise ValueError( + f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}" + " after downloading! Please check path and try again" + ) + else: + print( + "Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device." 
+ ) + + if not mlir_generated: + compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) + pkv = tuple( + (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) + for _ in range(64) + ) + secondVicunaCompileInput = (compilation_input_ids,) + pkv + model = SecondVicuna( + self.hf_model_path, self.precision, self.weight_group_size + ) + + print(f"[DEBUG] generating torchscript graph") + ts_graph = import_with_fx( + model, + secondVicunaCompileInput, + precision=self.precision, + f16_input_mask=[False] + [True] * 64, + mlir_type="torchscript", + ) + if self.precision == "fp16": + secondVicunaCompileInput = get_f16_inputs( + secondVicunaCompileInput, + True, + f16_input_mask=[False] + [True] * 64, + ) + secondVicunaCompileInput = list(secondVicunaCompileInput) + for i in range(len(secondVicunaCompileInput)): + if i != 0: + secondVicunaCompileInput[ + i + ] = torch_mlir.TensorPlaceholder.like( + secondVicunaCompileInput[i], dynamic_axes=[2] + ) + secondVicunaCompileInput = tuple(secondVicunaCompileInput) + + print(f"[DEBUG] generating torch mlir") + if self.precision in ["int4", "int8"]: + module = torch_mlir.compile( + ts_graph, + [*secondVicunaCompileInput], + output_type=torch_mlir.OutputType.TORCH, + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module = torch_mlir.compile( + ts_graph, + [*secondVicunaCompileInput], + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + + def remove_constant_dim(line): + if "c19_i64" in line: + line = re.sub("c19_i64", "dim_i64", line) + if "19x" in line: + line = re.sub("19x", "?x", line) + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dim)", line + ) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub( + "tensor.empty\(%dim\)", + "tensor.empty(%dim, %dim)", + line, + ) + if "arith.cmpi" in line: + line = re.sub("c19", "dim", line) + if " 19," in line: + line = re.sub(" 19,", " %dim,", line) + if "20x" in line: + line = re.sub("20x", "?x", line) + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dimp1)", line + ) + if " 20," in line: + line = re.sub(" 20,", " %dimp1,", line) + return line + + module_str = str(module) + new_lines = [] + + print(f"[DEBUG] rewriting torch_mlir file") + for line in module_str.splitlines(): + if "%c19_i64 = arith.constant 19 : i64" in line: + new_lines.append("%c2 = arith.constant 2 : index") + new_lines.append( + f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>" + ) + new_lines.append( + "%dim_i64 = arith.index_cast %dim_4_int : index to i64" + ) + continue + if "%c2 = arith.constant 2 : index" in line: + continue + if "%c20_i64 = arith.constant 20 : i64" in line: + new_lines.append("%c1_i64 = arith.constant 1 : i64") + new_lines.append( + "%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" + ) + new_lines.append( + "%dimp1 = arith.index_cast %c20_i64 : i64 to index" + ) + continue + line = remove_constant_dim(line) + new_lines.append(line) + + module_str = "\n".join(new_lines) + print(f"[DEBUG] converting to bytecode") + bytecode = module_str.encode("UTF-8") + bytecode_stream = BytesIO(bytecode) + bytecode = bytecode_stream.read() + + print(f"[DEBUG] writing mlir to file") + f_ = open(self.second_vicuna_mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" + ) + + path = shark_module.save_module( + self.second_vicuna_vmfb_path.parent.absolute(), + self.second_vicuna_vmfb_path.stem, + extra_args=[ + "--iree-hal-dump-executable-sources-to=ies", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + print("Saved second vic vmfb at ", str(path)) + shark_module.load_module(self.second_vicuna_vmfb_path) + + # self.shark_module = shark_module + return shark_module + + def compile(self): + # Cannot load both the models in the memory at once + # due to memory constraints, hence on demand compilation + # is being used until the space is enough for both models + + # Testing : DO NOT Download Vmfbs if not found. 
Modify later + # download vmfbs for A100 + if ( + not self.first_vicuna_vmfb_path.exists() + and self.device in ["cuda", "cpu"] + and self.precision in ["fp32", "fp16"] + ): + # combinations that are still in the works + if not (self.device == "cuda" and self.precision == "fp16"): + # Will generate vmfb on device + pass + else: + download_public_file( + f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}", + self.first_vicuna_vmfb_path.absolute(), + single_file=True, + ) + else: + # get first vic + # TODO: Remove after testing to avoid memory overload + # fvic_shark_model = self.compile_first_vicuna() + pass + if ( + not self.second_vicuna_vmfb_path.exists() + and self.device in ["cuda", "cpu"] + and self.precision in ["fp32", "fp16"] + ): + # combinations that are still in the works + if not (self.device == "cuda" and self.precision == "fp16"): + # Will generate vmfb on device + pass + else: + download_public_file( + f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}", + self.second_vicuna_vmfb_path.absolute(), + single_file=True, + ) + else: + # get second vic + # TODO: Remove after testing to avoid memory overload + # svic_shark_model = self.compile_second_vicuna() + pass + + return None + # return tuple of shark_modules once mem is supported + # return fvic_shark_model, svic_shark_model + + def decode_tokens(self, res_tokens): + for i in range(len(res_tokens)): + if type(res_tokens[i]) != int: + res_tokens[i] = int(res_tokens[i][0]) + + res_str = self.tokenizer.decode(res_tokens) + return res_str + + def generate(self, prompt, cli=False): + # TODO: refactor for cleaner integration + import gc + + if not self.low_device_memory: + if self.first_vic == None: + self.first_vic = self.compile_first_vicuna() + if self.second_vic == None: + self.second_vic = self.compile_second_vicuna() + res_tokens = [] + params = { + "prompt": prompt, + "is_first": True, + "fv": self.compile_first_vicuna() + if self.first_vic == None + else self.first_vic, + } + + generated_token_op = self.generate_new_token(params=params) + + token = generated_token_op["token"] + logits = generated_token_op["logits"] + pkv = generated_token_op["pkv"] + detok = generated_token_op["detok"] + yield detok + + res_tokens.append(token) + if cli: + print(f"Assistant: {detok}", end=" ", flush=True) + + # Clear First Vic from Memory (main and cuda) + if self.low_device_memory: + del params + torch.cuda.empty_cache() + gc.collect() + + for _ in range(self.max_num_tokens - 2): + params = { + "prompt": None, + "is_first": False, + "logits": logits, + "pkv": pkv, + "sv": self.compile_second_vicuna() + if self.second_vic == None + else self.second_vic, + } + + generated_token_op = self.generate_new_token(params=params) + + token = generated_token_op["token"] + logits = generated_token_op["logits"] + pkv = generated_token_op["pkv"] + detok = generated_token_op["detok"] + + if token == 2: + break + res_tokens.append(token) + if detok == "<0x0A>": + if cli: + print("\n", end="", flush=True) + else: + if cli: + print(f"{detok}", end=" ", flush=True) + + if len(res_tokens) % 3 == 0: + part_str = self.decode_tokens(res_tokens) + yield part_str + + if self.device == "cuda": + del sec_vic, pkv, logits + torch.cuda.empty_cache() + gc.collect() + + res_str = self.decode_tokens(res_tokens) + # print(f"[DEBUG] final output : \n{res_str}") + yield res_str + + def generate_new_token(self, params, debug=False): + def forward_first(first_vic, prompt, cache_outputs=False): + input_ids = 
self.tokenizer(prompt).input_ids + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + firstVicunaInput = (input_ids,) + assert first_vic is not None + output_first_vicuna = first_vic("forward", firstVicunaInput) + output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:]) + logits_first_vicuna = torch.tensor(output_first_vicuna[0]) + if cache_outputs: + torch.save( + logits_first_vicuna, "logits_first_vicuna_tensor.pt" + ) + torch.save( + output_first_vicuna_tensor, "output_first_vicuna_tensor.pt" + ) + token = torch.argmax( + torch.tensor(logits_first_vicuna)[:, -1, :], dim=1 + ) + return token, logits_first_vicuna, output_first_vicuna_tensor + + def forward_second(sec_vic, inputs=None, load_inputs=False): + if inputs is not None: + logits = inputs[0] + pkv = inputs[1:] + elif load_inputs: + pkv = torch.load("output_first_vicuna_tensor.pt") + pkv = tuple(torch.tensor(x) for x in pkv) + logits = torch.load("logits_first_vicuna_tensor.pt") + else: + print( + "Either inputs must be given, or load_inputs must be true" + ) + return None + token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1) + token = token.to(torch.int64).reshape([1, 1]) + secondVicunaInput = (token,) + tuple(pkv) + + secondVicunaOutput = sec_vic("forward", secondVicunaInput) + new_pkv = secondVicunaOutput[1:] + new_logits = secondVicunaOutput[0] + new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1) + return new_token, new_logits, new_pkv + + is_first = params["is_first"] + + if is_first: + prompt = params["prompt"] + fv = params["fv"] + token, logits, pkv = forward_first( + fv, # self.shark_model[0], + prompt=prompt, + cache_outputs=False, + ) + else: + _logits = params["logits"] + _pkv = params["pkv"] + inputs = (_logits,) + tuple(_pkv) + sv = params["sv"] + token, logits, pkv = forward_second( + sv, # self.shark_model[1], + inputs=inputs, + load_inputs=False, + ) + + detok = self.tokenizer.decode(token) + if debug: + print( + f"[DEBUG] is_first: {is_first} |" + f" token : {token} | detok : {detok}" + ) + ret_dict = { + "token": token, + "logits": logits, + "pkv": pkv, + "detok": detok, + } + return ret_dict + + def autocomplete(self, prompt): + # use First vic alone to complete a story / prompt / sentence. + pass + if __name__ == "__main__": args, unknown = parser.parse_known_args() @@ -91,7 +1519,7 @@ else Path(args.second_vicuna_vmfb_path) ) - vic = vp.Vicuna( + vic = UnshardedVicuna( "vicuna", device=args.device, precision=args.precision, @@ -100,6 +1528,7 @@ first_vicuna_vmfb_path=first_vic_vmfb_path, second_vicuna_vmfb_path=second_vic_vmfb_path, load_mlir_from_shark_tank=args.load_mlir_from_shark_tank, + weight_group_size=args.weight_group_size, ) else: if args.config is not None: @@ -108,11 +1537,12 @@ config_file.close() else: config_json = None - vic = vsp.Vicuna( + vic = ShardedVicuna( "vicuna", device=args.device, precision=args.precision, config_json=config_json, + weight_group_size=args.weight_group_size, ) prompt_history = "A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n" prologue_prompt = "ASSISTANT:\n" diff --git a/apps/language_models/src/model_wrappers/vicuna_model.py b/apps/language_models/src/model_wrappers/vicuna_model.py index 1753c13e77..3f84aeb2ae 100644 --- a/apps/language_models/src/model_wrappers/vicuna_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_model.py @@ -1,14 +1,39 @@ import torch from transformers import AutoModelForCausalLM +from brevitas_examples.llm.llm_quant.quantize import quantize_model +from brevitas_examples.llm.llm_quant.run_utils import get_model_impl + class FirstVicuna(torch.nn.Module): - def __init__(self, model_path): + def __init__(self, model_path, precision="fp32", weight_group_size=128): super().__init__() kwargs = {"torch_dtype": torch.float32} self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + if precision in ["int4", "int8"]: + print("First Vicuna applying weight quantization..") + weight_bit_width = 4 if precision == "int4" else 8 + quantize_model( + get_model_impl(self.model).layers, + dtype=torch.float32, + weight_quant_type="asym", + weight_bit_width=weight_bit_width, + weight_param_method="stats", + weight_scale_type="float", + weight_quant_granularity="per_group", + weight_group_size=weight_group_size, + quantize_weight_zero_point=False, + input_bit_width=None, + input_scale_type="float", + input_param_method="stats", + input_quant_type="asym", + input_quant_granularity="per_tensor", + quantize_input_zero_point=False, + seqlen=2048, + ) + print("Weight quantization applied.") def forward(self, input_ids): op = self.model(input_ids=input_ids, use_cache=True) @@ -22,12 +47,34 @@ def forward(self, input_ids): class SecondVicuna(torch.nn.Module): - def __init__(self, model_path): + def __init__(self, model_path, precision="fp32", weight_group_size=128): super().__init__() kwargs = {"torch_dtype": torch.float32} self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + if precision in ["int4", "int8"]: + print("Second Vicuna applying weight quantization..") + weight_bit_width = 4 if precision == "int4" else 8 + quantize_model( + get_model_impl(self.model).layers, + dtype=torch.float32, + weight_quant_type="asym", + weight_bit_width=weight_bit_width, + weight_param_method="stats", + weight_scale_type="float", + weight_quant_granularity="per_group", + weight_group_size=weight_group_size, + quantize_weight_zero_point=False, + input_bit_width=None, + input_scale_type="float", + input_param_method="stats", + input_quant_type="asym", + input_quant_granularity="per_tensor", + quantize_input_zero_point=False, + seqlen=2048, + ) + print("Weight quantization applied.") def forward( self, diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py deleted file mode 100644 index c859d7e243..0000000000 --- a/apps/language_models/src/pipelines/vicuna_pipeline.py +++ /dev/null @@ -1,612 +0,0 @@ -from apps.language_models.src.model_wrappers.vicuna_model import ( - FirstVicuna, - SecondVicuna, -) -from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase -from apps.language_models.utils import ( - get_vmfb_from_path, -) - -from io import BytesIO -from pathlib import Path -from shark.shark_downloader import download_public_file -from shark.shark_importer import import_with_fx, get_f16_inputs -from shark.shark_inference import SharkInference -from transformers import 
AutoTokenizer, AutoModelForCausalLM - -import re -import torch -import torch_mlir -import os - - -class Vicuna(SharkLLMBase): - def __init__( - self, - model_name, - hf_model_path="TheBloke/vicuna-7B-1.1-HF", - max_num_tokens=512, - device="cuda", - precision="fp32", - first_vicuna_mlir_path=None, - second_vicuna_mlir_path=None, - first_vicuna_vmfb_path=None, - second_vicuna_vmfb_path=None, - load_mlir_from_shark_tank=True, - low_device_memory=False, - ) -> None: - super().__init__(model_name, hf_model_path, max_num_tokens) - self.max_sequence_length = 256 - self.device = device - if not load_mlir_from_shark_tank and precision in ["int4", "int8"]: - print( - "int4 and int8 are only available from SHARK tank, please set --load_mlir_from_shark_tank, using fp32 now" - ) - precision = "fp32" - self.precision = precision - self.first_vicuna_vmfb_path = first_vicuna_vmfb_path - self.second_vicuna_vmfb_path = second_vicuna_vmfb_path - self.first_vicuna_mlir_path = first_vicuna_mlir_path - self.second_vicuna_mlir_path = second_vicuna_mlir_path - self.load_mlir_from_shark_tank = load_mlir_from_shark_tank - self.low_device_memory = low_device_memory - self.first_vic = None - self.second_vic = None - if self.first_vicuna_mlir_path == None: - self.first_vicuna_mlir_path = self.get_model_path() - if self.second_vicuna_mlir_path == None: - self.second_vicuna_mlir_path = self.get_model_path("second") - if self.first_vicuna_vmfb_path == None: - self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb") - if self.second_vicuna_vmfb_path == None: - self.second_vicuna_vmfb_path = self.get_model_path( - "second", "vmfb" - ) - self.tokenizer = self.get_tokenizer() - self.shark_model = self.compile() - - def get_model_path(self, model_number="first", suffix="mlir"): - safe_device = "_".join(self.device.split("-")) - if suffix == "mlir": - return Path(f"{model_number}_vicuna_{self.precision}.{suffix}") - return Path( - f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}" - ) - - def get_tokenizer(self): - tokenizer = AutoTokenizer.from_pretrained( - self.hf_model_path, use_fast=False - ) - return tokenizer - - def get_src_model(self): - kwargs = {"torch_dtype": torch.float} - vicuna_model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, **kwargs - ) - return vicuna_model - - def compile_first_vicuna(self): - vmfb = get_vmfb_from_path( - self.first_vicuna_vmfb_path, self.device, "tm_tensor" - ) - if vmfb is not None: - return vmfb - - # Compilation path needs some more work before it is functional - - print( - f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n" - f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}" - ) - if self.first_vicuna_mlir_path.exists(): - with open(self.first_vicuna_mlir_path, "rb") as f: - bytecode = f.read() - else: - mlir_generated = False - if self.load_mlir_from_shark_tank: - if self.precision in ["fp32", "fp16", "int8", "int4"]: - # download MLIR from shark_tank - download_public_file( - f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}", - self.first_vicuna_mlir_path.absolute(), - single_file=True, - ) - if self.first_vicuna_mlir_path.exists(): - with open(self.first_vicuna_mlir_path, "rb") as f: - bytecode = f.read() - mlir_generated = True - else: - raise ValueError( - f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}" - " after downloading! 
Please check path and try again" - ) - else: - print( - f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device." - ) - - if not mlir_generated: - compilation_prompt = "".join(["0" for _ in range(17)]) - compilation_input_ids = self.tokenizer( - compilation_prompt - ).input_ids - compilation_input_ids = torch.tensor( - compilation_input_ids - ).reshape([1, 19]) - firstVicunaCompileInput = (compilation_input_ids,) - model = FirstVicuna(self.hf_model_path) - - print(f"[DEBUG] generating torchscript graph") - ts_graph = import_with_fx( - model, - firstVicunaCompileInput, - is_f16=self.precision == "fp16", - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - del model - print(f"[DEBUG] generating torch mlir") - - firstVicunaCompileInput = list(firstVicunaCompileInput) - firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like( - firstVicunaCompileInput[0], dynamic_axes=[1] - ) - firstVicunaCompileInput = tuple(firstVicunaCompileInput) - module = torch_mlir.compile( - ts_graph, - [*firstVicunaCompileInput], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - del ts_graph - - def remove_constant_dim(line): - if "19x" in line: - line = re.sub("19x", "?x", line) - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dim)", line - ) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", - "tensor.empty(%dim, %dim)", - line, - ) - if "arith.cmpi" in line: - line = re.sub("c19", "dim", line) - if " 19," in line: - line = re.sub(" 19,", " %dim,", line) - return line - - module = str(module) - new_lines = [] - - print(f"[DEBUG] rewriting torch_mlir file") - for line in module.splitlines(): - line = remove_constant_dim(line) - if "%0 = tensor.empty(%dim) : tensor" in line: - new_lines.append( - "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" - ) - if ( - "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" - in line - ): - continue - - new_lines.append(line) - - module = "\n".join(new_lines) - - print(f"[DEBUG] converting to bytecode") - del new_lines - module = module.encode("UTF-8") - module = BytesIO(module) - bytecode = module.read() - del module - - print(f"[DEBUG] writing mlir to file") - f_ = open(self.first_vicuna_mlir_path, "wb") - f_.write(bytecode) - f_.close() - - shark_module = SharkInference( - mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" - ) - path = shark_module.save_module( - self.first_vicuna_vmfb_path.parent.absolute(), - self.first_vicuna_vmfb_path.stem, - extra_args=[ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ], - ) - print("Saved first vic vmfb at vmfb at ", str(path)) - shark_module.load_module(path) - - return shark_module - - def compile_second_vicuna(self): - vmfb = get_vmfb_from_path( - self.second_vicuna_vmfb_path, self.device, "tm_tensor" - ) - if vmfb is not None: - return vmfb - - # Compilation path needs some more work before it is functional - print( - f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}" - ) - if self.second_vicuna_mlir_path.exists(): - with open(self.second_vicuna_mlir_path, "rb") as f: - bytecode = f.read() - else: - mlir_generated = False - if self.load_mlir_from_shark_tank: - if self.precision in ["fp32", "fp16", "int8", "int4"]: - # download MLIR from shark_tank - 
download_public_file( - f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}", - self.second_vicuna_mlir_path.absolute(), - single_file=True, - ) - if self.second_vicuna_mlir_path.exists(): - with open(self.second_vicuna_mlir_path, "rb") as f: - bytecode = f.read() - mlir_generated = True - else: - raise ValueError( - f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}" - " after downloading! Please check path and try again" - ) - else: - print( - "Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device." - ) - - if not mlir_generated: - compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) - pkv = tuple( - (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) - for _ in range(64) - ) - secondVicunaCompileInput = (compilation_input_ids,) + pkv - model = SecondVicuna(self.hf_model_path) - ts_graph = import_with_fx( - model, - secondVicunaCompileInput, - is_f16=self.precision == "fp16", - f16_input_mask=[False] + [True] * 64, - mlir_type="torchscript", - ) - if self.precision == "fp16": - secondVicunaCompileInput = get_f16_inputs( - secondVicunaCompileInput, - True, - f16_input_mask=[False] + [True] * 64, - ) - secondVicunaCompileInput = list(secondVicunaCompileInput) - for i in range(len(secondVicunaCompileInput)): - if i != 0: - secondVicunaCompileInput[ - i - ] = torch_mlir.TensorPlaceholder.like( - secondVicunaCompileInput[i], dynamic_axes=[2] - ) - secondVicunaCompileInput = tuple(secondVicunaCompileInput) - module = torch_mlir.compile( - ts_graph, - [*secondVicunaCompileInput], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - def remove_constant_dim(line): - if "c19_i64" in line: - line = re.sub("c19_i64", "dim_i64", line) - if "19x" in line: - line = re.sub("19x", "?x", line) - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dim)", line - ) - if "tensor.empty" in line and "?x?" 
in line: - line = re.sub( - "tensor.empty\(%dim\)", - "tensor.empty(%dim, %dim)", - line, - ) - if "arith.cmpi" in line: - line = re.sub("c19", "dim", line) - if " 19," in line: - line = re.sub(" 19,", " %dim,", line) - if "20x" in line: - line = re.sub("20x", "?x", line) - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dimp1)", line - ) - if " 20," in line: - line = re.sub(" 20,", " %dimp1,", line) - return line - - module_str = str(module) - new_lines = [] - - for line in module_str.splitlines(): - if "%c19_i64 = arith.constant 19 : i64" in line: - new_lines.append("%c2 = arith.constant 2 : index") - new_lines.append( - f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>" - ) - new_lines.append( - "%dim_i64 = arith.index_cast %dim_4_int : index to i64" - ) - continue - if "%c2 = arith.constant 2 : index" in line: - continue - if "%c20_i64 = arith.constant 20 : i64" in line: - new_lines.append("%c1_i64 = arith.constant 1 : i64") - new_lines.append( - "%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" - ) - new_lines.append( - "%dimp1 = arith.index_cast %c20_i64 : i64 to index" - ) - continue - line = remove_constant_dim(line) - new_lines.append(line) - - module_str = "\n".join(new_lines) - bytecode = module_str.encode("UTF-8") - bytecode_stream = BytesIO(bytecode) - bytecode = bytecode_stream.read() - f_ = open(self.second_vicuna_mlir_path, "wb") - f_.write(bytecode) - f_.close() - - shark_module = SharkInference( - mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" - ) - - path = shark_module.save_module( - self.second_vicuna_vmfb_path.parent.absolute(), - self.second_vicuna_vmfb_path.stem, - extra_args=[ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ], - ) - print("Saved vmfb at ", str(path)) - shark_module.load_module(self.second_vicuna_vmfb_path) - - # self.shark_module = shark_module - - return shark_module - - def compile(self): - # Cannot load both the models in the memory at once - # due to memory constraints, hence on demand compilation - # is being used until the space is enough for both models - - # Testing : DO NOT Download Vmfbs if not found. 
-    def compile(self):
-        # Cannot load both the models in the memory at once
-        # due to memory constraints, hence on demand compilation
-        # is being used until the space is enough for both models
-
-        # Testing : DO NOT Download Vmfbs if not found. Modify later
-        # download vmfbs for A100
-        if (
-            not self.first_vicuna_vmfb_path.exists()
-            and self.device in ["cuda", "cpu"]
-            and self.precision in ["fp32", "fp16"]
-        ):
-            # combinations that are still in the works
-            if not (self.device == "cuda" and self.precision == "fp16"):
-                # Will generate vmfb on device
-                pass
-            else:
-                download_public_file(
-                    f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
-                    self.first_vicuna_vmfb_path.absolute(),
-                    single_file=True,
-                )
-        else:
-            # get first vic
-            # TODO: Remove after testing to avoid memory overload
-            # fvic_shark_model = self.compile_first_vicuna()
-            pass
-        if (
-            not self.second_vicuna_vmfb_path.exists()
-            and self.device in ["cuda", "cpu"]
-            and self.precision in ["fp32", "fp16"]
-        ):
-            # combinations that are still in the works
-            if not (self.device == "cuda" and self.precision == "fp16"):
-                # Will generate vmfb on device
-                pass
-            else:
-                download_public_file(
-                    f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
-                    self.second_vicuna_vmfb_path.absolute(),
-                    single_file=True,
-                )
-        else:
-            # get second vic
-            # TODO: Remove after testing to avoid memory overload
-            # svic_shark_model = self.compile_second_vicuna()
-            pass
-
-        return None
-        # return tuple of shark_modules once mem is supported
-        # return fvic_shark_model, svic_shark_model
-
-    def decode_tokens(self, res_tokens):
-        for i in range(len(res_tokens)):
-            if type(res_tokens[i]) != int:
-                res_tokens[i] = int(res_tokens[i][0])
-
-        res_str = self.tokenizer.decode(res_tokens)
-        return res_str
-
-    def generate(self, prompt, cli=False):
-        # TODO: refactor for cleaner integration
-        import gc
-
-        if not self.low_device_memory:
-            if self.first_vic == None:
-                self.first_vic = self.compile_first_vicuna()
-            if self.second_vic == None:
-                self.second_vic = self.compile_second_vicuna()
-        res_tokens = []
-        params = {
-            "prompt": prompt,
-            "is_first": True,
-            "fv": self.compile_first_vicuna()
-            if self.first_vic == None
-            else self.first_vic,
-        }
-
-        generated_token_op = self.generate_new_token(params=params)
-
-        token = generated_token_op["token"]
-        logits = generated_token_op["logits"]
-        pkv = generated_token_op["pkv"]
-        detok = generated_token_op["detok"]
-        yield detok
-
-        res_tokens.append(token)
-        if cli:
-            print(f"Assistant: {detok}", end=" ", flush=True)
-
-        # Clear First Vic from Memory (main and cuda)
-        if self.low_device_memory:
-            del params
-            torch.cuda.empty_cache()
-            gc.collect()
-
-        for _ in range(self.max_num_tokens - 2):
-            params = {
-                "prompt": None,
-                "is_first": False,
-                "logits": logits,
-                "pkv": pkv,
-                "sv": self.compile_second_vicuna()
-                if self.second_vic == None
-                else self.second_vic,
-            }
-
-            generated_token_op = self.generate_new_token(params=params)
-
-            token = generated_token_op["token"]
-            logits = generated_token_op["logits"]
-            pkv = generated_token_op["pkv"]
-            detok = generated_token_op["detok"]
-
-            if token == 2:
-                break
-            res_tokens.append(token)
-            if detok == "<0x0A>":
-                if cli:
-                    print("\n", end="", flush=True)
-            else:
-                if cli:
-                    print(f"{detok}", end=" ", flush=True)
-
-            if len(res_tokens) % 3 == 0:
-                part_str = self.decode_tokens(res_tokens)
-                yield part_str
-
-        if self.device == "cuda":
-            del sec_vic, pkv, logits
-            torch.cuda.empty_cache()
-            gc.collect()
-
-        res_str = self.decode_tokens(res_tokens)
-        # print(f"[DEBUG] final output : \n{res_str}")
-        yield res_str
-
-    def generate_new_token(self, params, debug=False):
-        def forward_first(first_vic, prompt, cache_outputs=False):
-            input_ids = self.tokenizer(prompt).input_ids
-            input_id_len = len(input_ids)
-            input_ids = torch.tensor(input_ids)
-            input_ids = input_ids.reshape([1, input_id_len])
-            firstVicunaInput = (input_ids,)
-            assert first_vic is not None
-            output_first_vicuna = first_vic("forward", firstVicunaInput)
-            output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:])
-            logits_first_vicuna = torch.tensor(output_first_vicuna[0])
-            if cache_outputs:
-                torch.save(
-                    logits_first_vicuna, "logits_first_vicuna_tensor.pt"
-                )
-                torch.save(
-                    output_first_vicuna_tensor, "output_first_vicuna_tensor.pt"
-                )
-            token = torch.argmax(
-                torch.tensor(logits_first_vicuna)[:, -1, :], dim=1
-            )
-            return token, logits_first_vicuna, output_first_vicuna_tensor
-
-        def forward_second(sec_vic, inputs=None, load_inputs=False):
-            if inputs is not None:
-                logits = inputs[0]
-                pkv = inputs[1:]
-            elif load_inputs:
-                pkv = torch.load("output_first_vicuna_tensor.pt")
-                pkv = tuple(torch.tensor(x) for x in pkv)
-                logits = torch.load("logits_first_vicuna_tensor.pt")
-            else:
-                print(
-                    "Either inputs must be given, or load_inputs must be true"
-                )
-                return None
-            token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1)
-            token = token.to(torch.int64).reshape([1, 1])
-            secondVicunaInput = (token,) + tuple(pkv)
-
-            secondVicunaOutput = sec_vic("forward", secondVicunaInput)
-            new_pkv = secondVicunaOutput[1:]
-            new_logits = secondVicunaOutput[0]
-            new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1)
-            return new_token, new_logits, new_pkv
-
-        is_first = params["is_first"]
-
-        if is_first:
-            prompt = params["prompt"]
-            fv = params["fv"]
-            token, logits, pkv = forward_first(
-                fv,  # self.shark_model[0],
-                prompt=prompt,
-                cache_outputs=False,
-            )
-        else:
-            _logits = params["logits"]
-            _pkv = params["pkv"]
-            inputs = (_logits,) + tuple(_pkv)
-            sv = params["sv"]
-            token, logits, pkv = forward_second(
-                sv,  # self.shark_model[1],
-                inputs=inputs,
-                load_inputs=False,
-            )
-
-        detok = self.tokenizer.decode(token)
-        if debug:
-            print(
-                f"[DEBUG] is_first: {is_first} |"
-                f" token : {token} | detok : {detok}"
-            )
-        ret_dict = {
-            "token": token,
-            "logits": logits,
-            "pkv": pkv,
-            "detok": detok,
-        }
-        return ret_dict
-
-    def autocomplete(self, prompt):
-        # use First vic alone to complete a story / prompt / sentence.
-        pass
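Both the pipeline deleted above and its replacement in vicuna.py drive generation the same way: a "first" module runs the whole prompt once and returns logits plus the key/value cache, then a "second" module is fed a single token and that cache on every later step. A schematic sketch of the loop, with placeholder callables standing in for the compiled modules (eos_id=2 mirrors the `if token == 2` check above):

import torch

def decode(first_model, second_model, tokenizer, prompt, max_new_tokens=32, eos_id=2):
    # Prefill: run the full prompt once, keep logits and the kv-cache.
    input_ids = torch.tensor(tokenizer(prompt).input_ids).reshape(1, -1)
    logits, pkv = first_model(input_ids)  # placeholder callable
    token = int(torch.argmax(logits[:, -1, :], dim=1))
    tokens = [token]
    # Decode: feed back one token and the cache per step.
    for _ in range(max_new_tokens - 1):
        logits, pkv = second_model(torch.tensor([[token]]), pkv)  # placeholder
        token = int(torch.argmax(logits[:, -1, :], dim=1))
        if token == eos_id:
            break
        tokens.append(token)
    return tokenizer.decode(tokens)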
diff --git a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py b/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
deleted file mode 100644
index cc88ce7bb3..0000000000
--- a/apps/language_models/src/pipelines/vicuna_sharded_pipeline.py
+++ /dev/null
@@ -1,686 +0,0 @@
-from apps.language_models.src.model_wrappers.vicuna_sharded_model import (
-    FirstVicunaLayer,
-    SecondVicunaLayer,
-    CompiledVicunaLayer,
-    ShardedVicunaModel,
-    LMHead,
-    LMHeadCompiled,
-    VicunaEmbedding,
-    VicunaEmbeddingCompiled,
-    VicunaNorm,
-    VicunaNormCompiled,
-)
-from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
-from shark.shark_importer import import_with_fx
-from io import BytesIO
-from pathlib import Path
-from shark.shark_inference import SharkInference
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from tqdm import tqdm
-from torch_mlir import TensorPlaceholder
-
-
-import re
-import torch
-import torch_mlir
-import os
-import json
-
-
-class Vicuna(SharkLLMBase):
-    # Class representing Sharded Vicuna Model
-    def __init__(
-        self,
-        model_name,
-        hf_model_path="TheBloke/vicuna-7B-1.1-HF",
-        max_num_tokens=512,
-        device="cuda",
-        precision="fp32",
-        config_json=None,
-    ) -> None:
-        super().__init__(model_name, hf_model_path, max_num_tokens)
-        self.max_sequence_length = 256
-        self.device = device
-        self.precision = precision
-        self.tokenizer = self.get_tokenizer()
-        self.config = config_json
-        self.shark_model = self.compile(device=device)
-
-    def get_tokenizer(self):
-        # Retrieve the tokenizer from Huggingface
-        tokenizer = AutoTokenizer.from_pretrained(
-            self.hf_model_path, use_fast=False
-        )
-        return tokenizer
-
-    def get_src_model(self):
-        # Retrieve the torch model from Huggingface
-        kwargs = {"torch_dtype": torch.float}
-        vicuna_model = AutoModelForCausalLM.from_pretrained(
-            self.hf_model_path, **kwargs
-        )
-        return vicuna_model
-
-    def write_in_dynamic_inputs0(self, module, dynamic_input_size):
-        # Current solution for ensuring mlir files support dynamic inputs
-        # TODO find a more elegant way to implement this
-        new_lines = []
-        for line in module.splitlines():
-            line = re.sub(f"{dynamic_input_size}x", "?x", line)
-            if "?x" in line:
-                line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
-            line = re.sub(f" {dynamic_input_size},", " %dim,", line)
-            if "tensor.empty" in line and "?x?" in line:
-                line = re.sub(
-                    "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line
-                )
-            if "arith.cmpi" in line:
-                line = re.sub(f"c{dynamic_input_size}", "dim", line)
-            new_lines.append(line)
-        new_module = "\n".join(new_lines)
-        return new_module
-
-    def write_in_dynamic_inputs1(self, module, dynamic_input_size):
-        new_lines = []
-        for line in module.splitlines():
-            if "dim_42 =" in line:
-                continue
-            if f"%c{dynamic_input_size}_i64 =" in line:
-                new_lines.append(
-                    "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>"
-                )
-                new_lines.append(
-                    f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64"
-                )
-                continue
-            line = re.sub(f"{dynamic_input_size}x", "?x", line)
-            line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line)
-            if "?x" in line:
-                line = re.sub(
-                    "tensor.empty\(\)", "tensor.empty(%dim_42)", line
-                )
-            line = re.sub(f" {dynamic_input_size},", " %dim_42,", line)
-            if "tensor.empty" in line and "?x?" in line:
-                line = re.sub(
-                    "tensor.empty\(%dim_42\)",
-                    "tensor.empty(%dim_42, %dim_42)",
-                    line,
-                )
-            if "arith.cmpi" in line:
-                line = re.sub(f"c{dynamic_input_size}", "dim_42", line)
-            new_lines.append(line)
-        new_module = "\n".join(new_lines)
-        return new_module
-
-    def combine_mlir_scripts(
-        self, first_vicuna_mlir, second_vicuna_mlir, output_name
-    ):
-        maps1 = []
-        maps2 = []
-        constants = set()
-        f1 = []
-        f2 = []
-
-        for line in first_vicuna_mlir.splitlines():
-            if re.search("#map\d*\s*=", line):
-                maps1.append(line)
-            elif re.search("arith.constant", line):
-                constants.add(line)
-            elif not re.search("module", line):
-                line = re.sub("forward", "first_vicuna_forward", line)
-                f1.append(line)
-        f1 = f1[:-1]
-
-        for i, map_line in enumerate(maps1):
-            map_var = map_line.split(" ")[0]
-            map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line)
-            maps1[i] = map_line
-            f1 = [
-                re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line)
-                for func_line in f1
-            ]
-
-        for line in second_vicuna_mlir.splitlines():
-            if re.search("#map\d*\s*=", line):
-                maps2.append(line)
-            elif "global_seed" in line:
-                continue
-            elif re.search("arith.constant", line):
-                constants.add(line)
-            elif not re.search("module", line):
-                line = re.sub("forward", "second_vicuna_forward", line)
-                f2.append(line)
-        f2 = f2[:-1]
-
-        for i, map_line in enumerate(maps2):
-            map_var = map_line.split(" ")[0]
-            map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line)
-            maps2[i] = map_line
-            f2 = [
-                re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line)
-                for func_line in f2
-            ]
-
-        module_start = (
-            'module attributes {torch.debug_module_name = "_lambda"} {'
-        )
-        module_end = "}"
-
-        global_vars = []
-        vnames = []
-        vdtypes = []
-        global_var_loading1 = []
-        global_var_loading2 = []
-
-        for constant in list(constants):
-            vname, vbody = constant.split("=")
-            vname = re.sub("%", "", vname)
-            vname = vname.strip()
-            vbody = re.sub("arith.constant", "", vbody)
-            vbody = vbody.strip()
-            vdtype = vbody.split(":")[1].strip()
-            fixed_vdtype = vdtype
-            vdtypes.append(vdtype)
-            vdtype = re.sub("\d{1,}x", "?x", vdtype)
-            vnames.append(vname)
-            global_vars.append(
-                f"ml_program.global public @{vname}({vbody}) : {fixed_vdtype}"
-            )
-            global_var_loading1.append(
-                f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
-            )
-            global_var_loading2.append(
-                f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
-            )
-
-        new_f1, new_f2 = [], []
-
-        for line in f1:
-            if "func.func" in line:
-                new_f1.append(line)
-                for global_var in global_var_loading1:
-                    new_f1.append(global_var)
-            else:
-                new_f1.append(line)
-
-        for line in f2:
-            if "func.func" in line:
-                new_f2.append(line)
-                for global_var in global_var_loading1:
-                    new_f2.append(global_var)
-            else:
-                new_f2.append(line)
-
-        f1 = new_f1
-        f2 = new_f2
-
-        whole_string = "\n".join(
-            maps1
-            + maps2
-            + [module_start]
-            + global_vars
-            + f1
-            + f2
-            + [module_end]
-        )
-
-        f_ = open(output_name, "w+")
-        f_.write(whole_string)
-        f_.close()
-
-        return whole_string
-
-    def compile_vicuna_layer(
-        self,
-        vicuna_layer,
-        hidden_states,
-        attention_mask,
-        position_ids,
-        past_key_value0=None,
-        past_key_value1=None,
-    ):
-        # Compile a hidden decoder layer of vicuna
-        if past_key_value0 is None and past_key_value1 is None:
-            model_inputs = (hidden_states, attention_mask, position_ids)
-        else:
-            model_inputs = (
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value0,
-                past_key_value1,
-            )
-        mlir_bytecode = import_with_fx(
-            vicuna_layer,
-            model_inputs,
-            is_f16=self.precision == "fp16",
-            f16_input_mask=[False, False],
-            mlir_type="torchscript",
-        )
-        return mlir_bytecode
-
-    def get_device_index(self, layer_string):
-        # Get the device index from the config file
-        # In the event that different device indices are assigned to
-        # different parts of a layer, a majority vote will be taken and
-        # everything will be run on the most commonly used device
-        if self.config is None:
-            return None
-        idx_votes = {}
-        for key in self.config.keys():
-            if re.search(layer_string, key):
-                if int(self.config[key]["gpu"]) in idx_votes.keys():
-                    idx_votes[int(self.config[key]["gpu"])] += 1
-                else:
-                    idx_votes[int(self.config[key]["gpu"])] = 1
-        device_idx = max(idx_votes, key=idx_votes.get)
-        return device_idx
-
-    def compile_lmhead(
-        self, lmh, hidden_states, device="cpu", device_idx=None
-    ):
-        # compile the lm head of the vicuna model
-        # This can be used for both first and second vicuna, so only needs to be run once
-        mlir_path = Path(f"lmhead.mlir")
-        vmfb_path = Path(f"lmhead.vmfb")
-        if mlir_path.exists():
-            f_ = open(mlir_path, "rb")
-            bytecode = f_.read()
-            f_.close()
-        else:
-            hidden_states = torch_mlir.TensorPlaceholder.like(
-                hidden_states, dynamic_axes=[1]
-            )
-
-            module = torch_mlir.compile(
-                lmh,
-                (hidden_states,),
-                torch_mlir.OutputType.LINALG_ON_TENSORS,
-                use_tracing=False,
-                verbose=False,
-            )
-            bytecode_stream = BytesIO()
-            module.operation.write_bytecode(bytecode_stream)
-            bytecode = bytecode_stream.getvalue()
-            f_ = open(mlir_path, "wb")
-            f_.write(bytecode)
-            f_.close()
-
-        shark_module = SharkInference(
-            bytecode,
-            device=device,
-            mlir_dialect="tm_tensor",
-            device_idx=device_idx,
-            mmap=False,
-        )
-        if vmfb_path.exists():
-            shark_module.load_module(vmfb_path)
-        else:
-            shark_module.save_module(module_name="lmhead")
-            shark_module.load_module(vmfb_path)
-        compiled_module = LMHeadCompiled(shark_module)
-        return compiled_module
-
-    def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None):
-        # compile the normalization layer of the vicuna model
-        # This can be used for both first and second vicuna, so only needs to be run once
-        mlir_path = Path(f"norm.mlir")
-        vmfb_path = Path(f"norm.vmfb")
-        if mlir_path.exists():
-            f_ = open(mlir_path, "rb")
-            bytecode = f_.read()
-            f_.close()
-        else:
-            hidden_states = torch_mlir.TensorPlaceholder.like(
-                hidden_states, dynamic_axes=[1]
-            )
-
-            module = torch_mlir.compile(
-                fvn,
-                (hidden_states,),
-                torch_mlir.OutputType.LINALG_ON_TENSORS,
-                use_tracing=False,
-                verbose=False,
-            )
-            bytecode_stream = BytesIO()
-            module.operation.write_bytecode(bytecode_stream)
-            bytecode = bytecode_stream.getvalue()
-            f_ = open(mlir_path, "wb")
-            f_.write(bytecode)
-            f_.close()
-
-        shark_module = SharkInference(
-            bytecode,
-            device=device,
-            mlir_dialect="tm_tensor",
-            device_idx=device_idx,
-            mmap=False,
-        )
-        if vmfb_path.exists():
-            shark_module.load_module(vmfb_path)
-        else:
-            shark_module.save_module(module_name="norm")
-            shark_module.load_module(vmfb_path)
-        compiled_module = VicunaNormCompiled(shark_module)
-        return compiled_module
-
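compile_lmhead and compile_norm above (and compile_embedding below) repeat one cache-on-disk pattern: reuse an existing .mlir file or lower the module with torch_mlir and write the bytecode, then either load a saved .vmfb or build and save one. A condensed sketch of that pattern; the helper name and arguments are illustrative, not part of the deleted file:

from io import BytesIO
from pathlib import Path

import torch_mlir
from shark.shark_inference import SharkInference

def compile_or_load(module_name, torch_module, example_input, device="cpu"):
    mlir_path = Path(f"{module_name}.mlir")
    vmfb_path = Path(f"{module_name}.vmfb")
    if mlir_path.exists():
        bytecode = mlir_path.read_bytes()
    else:
        compiled = torch_mlir.compile(
            torch_module,
            (example_input,),
            torch_mlir.OutputType.LINALG_ON_TENSORS,
        )
        stream = BytesIO()
        compiled.operation.write_bytecode(stream)
        bytecode = stream.getvalue()
        mlir_path.write_bytes(bytecode)
    shark_module = SharkInference(bytecode, device=device, mlir_dialect="tm_tensor")
    if vmfb_path.exists():
        shark_module.load_module(vmfb_path)
    else:
        shark_module.save_module(module_name=module_name)
        shark_module.load_module(vmfb_path)
    return shark_module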
-    def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None):
-        # compile the embedding layer of the vicuna model
-        # This can be used for both first and second vicuna, so only needs to be run once
-        mlir_path = Path(f"embedding.mlir")
-        vmfb_path = Path(f"embedding.vmfb")
-        if mlir_path.exists():
-            f_ = open(mlir_path, "rb")
-            bytecode = f_.read()
-            f_.close()
-        else:
-            input_ids = torch_mlir.TensorPlaceholder.like(
-                input_ids, dynamic_axes=[1]
-            )
-            module = torch_mlir.compile(
-                fve,
-                (input_ids,),
-                torch_mlir.OutputType.LINALG_ON_TENSORS,
-                use_tracing=False,
-                verbose=False,
-            )
-            bytecode_stream = BytesIO()
-            module.operation.write_bytecode(bytecode_stream)
-            bytecode = bytecode_stream.getvalue()
-            f_ = open(mlir_path, "wb")
-            f_.write(bytecode)
-            f_.close()
-
-        shark_module = SharkInference(
-            bytecode,
-            device=device,
-            mlir_dialect="tm_tensor",
-            device_idx=device_idx,
-            mmap=False,
-        )
-        if vmfb_path.exists():
-            shark_module.load_module(vmfb_path)
-        else:
-            shark_module.save_module(module_name="embedding")
-            shark_module.load_module(vmfb_path)
-        compiled_module = VicunaEmbeddingCompiled(shark_module)
-
-        return compiled_module
-
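compile_to_vmfb_one_model below leans on torch_mlir's TensorPlaceholder to keep selected dimensions symbolic (axis 1 of the hidden states, axis 3 of the attention mask, axis 2 of the key/value caches). A minimal sketch of passing such a placeholder to torch_mlir.compile; the toy module is illustrative only:

import torch
import torch_mlir
from torch_mlir import TensorPlaceholder

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

example = torch.zeros([1, 137, 4096])
# Axis 1 (the sequence length) stays dynamic in the lowered IR.
placeholder = TensorPlaceholder.like(example, dynamic_axes=[1])
module = torch_mlir.compile(
    Toy(), (placeholder,), torch_mlir.OutputType.LINALG_ON_TENSORS
)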
"--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ], - ) - module.load_module(vmfb_path) - modules.append(module) - return mlirs, modules - - def get_sharded_model(self, device="cpu"): - # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess - # please don't change it - SAMPLE_INPUT_LEN = 137 - vicuna_model = self.get_src_model() - placeholder_input0 = ( - torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), - torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]), - torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64), - ) - - placeholder_input1 = ( - torch.zeros([1, 1, 4096]), - torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]), - torch.zeros([1, 1], dtype=torch.int64), - torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), - torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), - ) - - norm = VicunaNorm(vicuna_model.model.norm) - device_idx = self.get_device_index( - r"vicuna\.model\.model\.norm(?:\.|\s|$)" - ) - print(device_idx) - norm = self.compile_norm( - norm, - torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), - device=self.device, - device_idx=device_idx, - ) - - embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens) - device_idx = self.get_device_index( - r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)" - ) - print(device_idx) - embeddings = self.compile_embedding( - embeddings, - (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)), - device=self.device, - device_idx=device_idx, - ) - - lmhead = LMHead(vicuna_model.lm_head) - device_idx = self.get_device_index( - r"vicuna\.model\.lm_head(?:\.|\s|$)" - ) - print(device_idx) - lmhead = self.compile_lmhead( - lmhead, - torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), - device=self.device, - device_idx=device_idx, - ) - - layers0 = [ - FirstVicunaLayer(layer) for layer in vicuna_model.model.layers - ] - - layers1 = [ - SecondVicunaLayer(layer) for layer in vicuna_model.model.layers - ] - - _, modules = self.compile_to_vmfb_one_model( - placeholder_input0, - layers0, - placeholder_input1, - layers1, - device=device, - ) - shark_layers = [CompiledVicunaLayer(m) for m in modules] - - sharded_model = ShardedVicunaModel( - vicuna_model, - shark_layers, - lmhead, - embeddings, - norm, - ) - return sharded_model - - def compile(self, device="cpu"): - return self.get_sharded_model(device=device) - - def generate(self, prompt, cli=False): - # TODO: refactor for cleaner integration - - tokens_generated = [] - _past_key_values = None - _token = None - detoks_generated = [] - for iteration in range(self.max_num_tokens): - params = { - "prompt": prompt, - "is_first": iteration == 0, - "token": _token, - "past_key_values": _past_key_values, - } - - generated_token_op = self.generate_new_token(params=params) - - _token = generated_token_op["token"] - _past_key_values = generated_token_op["past_key_values"] - _detok = generated_token_op["detok"] - - if _token == 2: - break - detoks_generated.append(_detok) - tokens_generated.append(_token) - - for i in range(len(tokens_generated)): - if type(tokens_generated[i]) != int: - tokens_generated[i] = int(tokens_generated[i][0]) - result_output = self.tokenizer.decode(tokens_generated) - return result_output - - def generate_new_token(self, params): - is_first = params["is_first"] - if is_first: - prompt = params["prompt"] - input_ids = self.tokenizer(prompt).input_ids - input_id_len = len(input_ids) - input_ids = torch.tensor(input_ids) - input_ids = input_ids.reshape([1, input_id_len]) - output = 
self.shark_model.forward(input_ids, is_first=is_first) - else: - token = params["token"] - past_key_values = params["past_key_values"] - input_ids = [token] - input_id_len = len(input_ids) - input_ids = torch.tensor(input_ids) - input_ids = input_ids.reshape([1, input_id_len]) - output = self.shark_model.forward( - input_ids, past_key_values=past_key_values, is_first=is_first - ) - - _logits = output["logits"] - _past_key_values = output["past_key_values"] - _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) - _detok = self.tokenizer.decode(_token) - - ret_dict = { - "token": _token, - "detok": _detok, - "past_key_values": _past_key_values, - } - - print(f" token : {_token} | detok : {_detok}") - - return ret_dict - - def autocomplete(self, prompt): - # use First vic alone to complete a story / prompt / sentence. - pass diff --git a/requirements.txt b/requirements.txt index 8eeafbc7a7..f8e1efc502 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,3 +33,6 @@ sentencepiece # Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors pefile pyinstaller + +# low precision vicuna +brevitas @ git+https://github.com/Xilinx/brevitas.git@llm diff --git a/setup_venv.sh b/setup_venv.sh index e80560607b..35a7f64ecb 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -159,3 +159,5 @@ if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then echo "${Green}Before running examples activate venv with:" echo " ${Green}source $VENV_DIR/bin/activate" fi + +$PYTHON -m pip install git+https://github.com/Xilinx/brevitas.git@llm diff --git a/shark/shark_importer.py b/shark/shark_importer.py index e12f7c0922..e67be22643 100644 --- a/shark/shark_importer.py +++ b/shark/shark_importer.py @@ -353,7 +353,7 @@ def add_upcast(fx_g): fx_g.graph.lint() -def transform_fx(fx_g): +def transform_fx(fx_g, quantized=False): import torch kwargs_dict = { @@ -366,6 +366,19 @@ def transform_fx(fx_g): } for node in fx_g.graph.nodes: if node.op == "call_function": + # aten.empty should be filled with zeros. + if node.target in [torch.ops.aten.empty]: + with fx_g.graph.inserting_after(node): + new_node = fx_g.graph.call_function( + torch.ops.aten.zero_, + args=(node,), + ) + node.append(new_node) + node.replace_all_uses_with(new_node) + new_node.args = (node,) + if quantized: + continue + if node.target in [ torch.ops.aten.arange, torch.ops.aten.empty, @@ -427,17 +440,6 @@ def transform_fx(fx_g): new_node.args = (node,) new_node.kwargs = {"dtype": torch.float16} - # aten.empty should be filled with zeros. - if node.target in [torch.ops.aten.empty]: - with fx_g.graph.inserting_after(node): - new_node = fx_g.graph.call_function( - torch.ops.aten.zero_, - args=(node,), - ) - node.append(new_node) - node.replace_all_uses_with(new_node) - new_node.args = (node,) - # Required for cuda debugging. # for node in fx_g.graph.nodes: # if node.op == "call_function": @@ -486,6 +488,7 @@ def flatten_training_input(inputs): return tuple(flattened_input) +# TODO: get rid of is_f16 by using precision # Applies fx conversion to the model and imports the mlir. 
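In the transform_fx hunks above, the aten.empty zero-fill is moved ahead of the float16 rewrites so that a quantized graph still gets deterministic empty buffers and then skips the rest of the pass via the early continue. A standalone sketch of the same FX rewrite on a toy graph; the toy function and the overload-packet match are illustrative adjustments, not code from the patch:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def zero_fill_empty(fx_g: torch.fx.GraphModule) -> None:
    # Insert aten.zero_ right after every aten.empty so the buffer contents
    # are deterministic, mirroring the transform_fx change above.
    for node in fx_g.graph.nodes:
        if node.op == "call_function" and getattr(
            node.target, "overloadpacket", node.target
        ) == torch.ops.aten.empty:
            with fx_g.graph.inserting_after(node):
                new_node = fx_g.graph.call_function(
                    torch.ops.aten.zero_, args=(node,)
                )
            node.replace_all_uses_with(new_node)
            new_node.args = (node,)  # restore the argument clobbered above
    fx_g.graph.lint()
    fx_g.recompile()

fx_g = make_fx(lambda x: torch.empty(2, 2) + x)(torch.ones(2, 2))
zero_fill_empty(fx_g)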
@@ -500,10 +503,28 @@ def import_with_fx(
     mlir_type="linalg",
     is_dynamic=False,
     tracing_required=False,
+    precision="fp32",
 ):
     import torch
     from torch.fx.experimental.proxy_tensor import make_fx
     from torch._decomp import get_decompositions
+    from typing import List
+    from brevitas_examples.llm.llm_quant.export import (
+        block_quant_layer_level_manager,
+    )
+    from brevitas_examples.llm.llm_quant.export import (
+        brevitas_layer_export_mode,
+    )
+    from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
+        LinearWeightBlockQuantHandlerFwd,
+    )
+    from brevitas_examples.llm.llm_quant.export import replace_call_fn_target
+    from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import (
+        matmul_rhs_group_quant_placeholder,
+    )
+    from brevitas.backport.fx.experimental.proxy_tensor import (
+        make_fx as brevitas_make_fx,
+    )
 
     golden_values = None
     if debug:
@@ -511,26 +532,97 @@
             golden_values = model(*inputs)
         except:
             golden_values = None
+
+    def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
+        removed_indexes = []
+        for node in fx_g.graph.nodes:
+            if node.op == "output":
+                assert (
+                    len(node.args) == 1
+                ), "Output node must have a single argument"
+                node_arg = node.args[0]
+                if isinstance(node_arg, (list, tuple)):
+                    node_arg = list(node_arg)
+                    node_args_len = len(node_arg)
+                    for i in range(node_args_len):
+                        curr_index = node_args_len - (i + 1)
+                        if node_arg[curr_index] is None:
+                            removed_indexes.append(curr_index)
+                            node_arg.pop(curr_index)
+                    node.args = (tuple(node_arg),)
+                    break
+
+        if len(removed_indexes) > 0:
+            fx_g.graph.lint()
+            fx_g.graph.eliminate_dead_code()
+            fx_g.recompile()
+        removed_indexes.sort()
+        return removed_indexes
+
+    def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
+        """
+        Replace tuple with tuple element in functions that return one-element tuples.
+        Returns true if an unwrapping took place, and false otherwise.
+        """
+        unwrapped_tuple = False
+        for node in fx_g.graph.nodes:
+            if node.op == "output":
+                assert (
+                    len(node.args) == 1
+                ), "Output node must have a single argument"
+                node_arg = node.args[0]
+                if isinstance(node_arg, tuple):
+                    if len(node_arg) == 1:
+                        node.args = (node_arg[0],)
+                        unwrapped_tuple = True
+                    break
+
+        if unwrapped_tuple:
+            fx_g.graph.lint()
+            fx_g.recompile()
+        return unwrapped_tuple
+
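The _remove_nones and _unwrap_single_tuple_return helpers added above normalize the traced graph's output before MLIR import: the first strips None entries from a tuple output while remembering their positions, the second turns a one-element tuple return into a bare value. A small illustration of the unwrapping step on a toy graph; the re-stated helper below is a trimmed copy for demonstration only, since the patch nests these helpers inside import_with_fx:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
    # Trimmed restatement of the helper above, for illustration.
    unwrapped = False
    for node in fx_g.graph.nodes:
        if node.op == "output":
            node_arg = node.args[0]
            if isinstance(node_arg, tuple) and len(node_arg) == 1:
                node.args = (node_arg[0],)
                unwrapped = True
            break
    if unwrapped:
        fx_g.graph.lint()
        fx_g.recompile()
    return unwrapped

fx_g = make_fx(lambda x: (x + 1,))(torch.ones(2))
print(unwrap_single_tuple_return(fx_g))  # True: forward now returns a bare tensor
print(fx_g(torch.ones(2)))               # tensor([2., 2.])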
     # TODO: Control the decompositions.
-    fx_g = make_fx(
-        model,
-        decomposition_table=get_decompositions(
-            [
-                torch.ops.aten.embedding_dense_backward,
-                torch.ops.aten.native_layer_norm_backward,
-                torch.ops.aten.slice_backward,
-                torch.ops.aten.select_backward,
-                torch.ops.aten.norm.ScalarOpt_dim,
-                torch.ops.aten.native_group_norm,
-                torch.ops.aten.upsample_bilinear2d.vec,
-                torch.ops.aten.split.Tensor,
-                torch.ops.aten.split_with_sizes,
-                torch.ops.aten.native_layer_norm,
-                torch.ops.aten.masked_fill.Tensor,
-                torch.ops.aten.masked_fill.Scalar,
-            ]
-        ),
-    )(*inputs)
+    decomps_list = [
+        torch.ops.aten.embedding_dense_backward,
+        torch.ops.aten.native_layer_norm_backward,
+        torch.ops.aten.slice_backward,
+        torch.ops.aten.select_backward,
+        torch.ops.aten.norm.ScalarOpt_dim,
+        torch.ops.aten.native_group_norm,
+        torch.ops.aten.upsample_bilinear2d.vec,
+        torch.ops.aten.split.Tensor,
+        torch.ops.aten.split_with_sizes,
+        torch.ops.aten.native_layer_norm,
+        torch.ops.aten.masked_fill.Tensor,
+        torch.ops.aten.masked_fill.Scalar,
+    ]
+    if precision in ["int4", "int8"]:
+        export_context_manager = brevitas_layer_export_mode
+        export_class = block_quant_layer_level_manager(
+            export_handlers=[LinearWeightBlockQuantHandlerFwd]
+        )
+        with export_context_manager(model, export_class):
+            fx_g = brevitas_make_fx(
+                model,
+                decomposition_table=get_decompositions(decomps_list),
+            )(*inputs)
+
+        transform_fx(fx_g, quantized=True)
+        replace_call_fn_target(
+            fx_g,
+            src=matmul_rhs_group_quant_placeholder,
+            target=torch.ops.brevitas.matmul_rhs_group_quant,
+        )
+
+        fx_g.recompile()
+        removed_none_indexes = _remove_nones(fx_g)
+        was_unwrapped = _unwrap_single_tuple_return(fx_g)
+    else:
+        fx_g = make_fx(
+            model,
+            decomposition_table=get_decompositions(decomps_list),
+        )(*inputs)
 
     fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
     fx_g.recompile()