diff --git a/.gitmodules b/.gitmodules index 23d23140c7..2412040acb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,7 @@ path = inference/thirdparty/shark-runtime url =https://github.com/nod-ai/SHARK-Runtime.git branch = shark-06032022 +[submodule "third_party/brevitas"] + path = third_party/brevitas + url = https://github.com/Xilinx/brevitas.git + branch = llm diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 54c32c8269..7e47d06ba6 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1,9 +1,45 @@ import argparse +import json +import re +from io import BytesIO from pathlib import Path -from apps.language_models.src.pipelines import vicuna_pipeline as vp -from apps.language_models.src.pipelines import vicuna_sharded_pipeline as vsp +from tqdm import tqdm +from typing import List, Tuple + import torch -import json +import torch_mlir +from torch_mlir import TensorPlaceholder +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from transformers import AutoTokenizer, AutoModelForCausalLM + +from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase +from apps.language_models.src.model_wrappers.vicuna_sharded_model import ( + FirstVicunaLayer, + SecondVicunaLayer, + CompiledFirstVicunaLayer, + CompiledSecondVicunaLayer, + ShardedVicunaModel, + LMHead, + LMHeadCompiled, + VicunaEmbedding, + VicunaEmbeddingCompiled, + VicunaNorm, + VicunaNormCompiled, +) +from apps.language_models.src.model_wrappers.vicuna_model import ( + FirstVicuna, + SecondVicuna, +) +from apps.language_models.utils import ( + get_vmfb_from_path, +) +from shark.shark_downloader import download_public_file +from shark.shark_importer import get_f16_inputs +from shark.shark_importer import import_with_fx +from shark.shark_inference import SharkInference + +from brevitas_examples.llm.llm_quant.quantize import quantize_model +from brevitas_examples.llm.llm_quant.run_utils import get_model_impl if __name__ == "__main__": import gc @@ -13,7 +49,6 @@ prog="vicuna runner", description="runs a vicuna model", ) - parser.add_argument( "--precision", "-p", default="fp32", help="fp32, fp16, int8, int4" ) @@ -29,7 +64,6 @@ help="Run model as sharded", ) # TODO: sharded config - parser.add_argument( "--second_vicuna_vmfb_path", default=None, @@ -45,7 +79,6 @@ default=None, help="path to second vicuna mlir", ) - parser.add_argument( "--load_mlir_from_shark_tank", default=False, @@ -58,12 +91,1339 @@ action=argparse.BooleanOptionalAction, help="Run model in cli mode", ) - parser.add_argument( "--config", default=None, help="configuration file", ) +parser.add_argument( + "--weight-group-size", + type=int, + default=128, + help="Group size for per_group weight quantization. 
Default: 128.", +) + + +def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: + if len(lhs) == 3 and len(rhs) == 2: + return [lhs[0], lhs[1], rhs[0]] + elif len(lhs) == 2 and len(rhs) == 2: + return [lhs[0], rhs[0]] + else: + raise ValueError("Input shapes not supported.") + + +def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: + # output dtype is the dtype of the lhs float input + lhs_rank, lhs_dtype = lhs_rank_dtype + return lhs_dtype + + +def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: + return + + +brevitas_matmul_rhs_group_quant_library = [ + brevitas〇matmul_rhs_group_quant〡shape, + brevitas〇matmul_rhs_group_quant〡dtype, + brevitas〇matmul_rhs_group_quant〡has_value_semantics] + + +class ShardedVicuna(SharkLLMBase): + # Class representing Sharded Vicuna Model + def __init__( + self, + model_name, + hf_model_path="TheBloke/vicuna-7B-1.1-HF", + max_num_tokens=512, + device="cuda", + precision="fp32", + config_json=None, + weight_group_size=128, + ) -> None: + super().__init__(model_name, hf_model_path, max_num_tokens) + self.max_sequence_length = 256 + self.device = device + self.precision = precision + self.tokenizer = self.get_tokenizer() + self.config = config_json + self.weight_group_size = weight_group_size + self.shark_model = self.compile(device=device) + + def get_tokenizer(self): + # Retrieve the tokenizer from Huggingface + tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_path, use_fast=False + ) + return tokenizer + + def get_src_model(self): + # Retrieve the torch model from Huggingface + kwargs = {"torch_dtype": torch.float} + vicuna_model = AutoModelForCausalLM.from_pretrained( + self.hf_model_path, **kwargs + ) + return vicuna_model + + def write_in_dynamic_inputs0(self, module, dynamic_input_size): + # Current solution for ensuring mlir files support dynamic inputs + # TODO find a more elegant way to implement this + new_lines = [] + for line in module.splitlines(): + line = re.sub(f"{dynamic_input_size}x", "?x", line) + if "?x" in line: + line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) + line = re.sub(f" {dynamic_input_size},", " %dim,", line) + if "tensor.empty" in line and "?x?" in line: + line = re.sub( + "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line + ) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + + def write_in_dynamic_inputs1(self, module, dynamic_input_size): + new_lines = [] + for line in module.splitlines(): + if "dim_42 =" in line: + continue + if f"%c{dynamic_input_size}_i64 =" in line: + new_lines.append( + "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>" + ) + new_lines.append( + f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64" + ) + continue + line = re.sub(f"{dynamic_input_size}x", "?x", line) + if "?x" in line: + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dim_42)", line + ) + line = re.sub(f" {dynamic_input_size},", " %dim_42,", line) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub( + "tensor.empty\(%dim_42\)", + "tensor.empty(%dim_42, %dim_42)", + line, + ) + if "arith.cmpi" in line: + line = re.sub(f"c{dynamic_input_size}", "dim_42", line) + new_lines.append(line) + new_module = "\n".join(new_lines) + return new_module + + def compile_vicuna_layer( + self, + vicuna_layer, + hidden_states, + attention_mask, + position_ids, + past_key_value0=None, + past_key_value1=None, + ): + # Compile a hidden decoder layer of vicuna + if past_key_value0 is None and past_key_value1 is None: + model_inputs = (hidden_states, attention_mask, position_ids) + else: + model_inputs = ( + hidden_states, + attention_mask, + position_ids, + past_key_value0, + past_key_value1, + ) + mlir_bytecode = import_with_fx( + vicuna_layer, + model_inputs, + is_f16=self.precision == "fp16", + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + return mlir_bytecode + + def get_device_index(self, layer_string): + # Get the device index from the config file + # In the event that different device indices are assigned to + # different parts of a layer, a majority vote will be taken and + # everything will be run on the most commonly used device + if self.config is None: + return None + idx_votes = {} + for key in self.config.keys(): + if re.search(layer_string, key): + if int(self.config[key]["gpu"]) in idx_votes.keys(): + idx_votes[int(self.config[key]["gpu"])] += 1 + else: + idx_votes[int(self.config[key]["gpu"])] = 1 + device_idx = max(idx_votes, key=idx_votes.get) + return device_idx + + def compile_lmhead( + self, lmh, hidden_states, device="cpu", device_idx=None + ): + # compile the lm head of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"lmhead.mlir") + vmfb_path = Path(f"lmhead.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + lmh, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="lmhead") + shark_module.load_module(vmfb_path) + compiled_module = LMHeadCompiled(shark_module) + return compiled_module + + def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): + # compile the normalization layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"norm.mlir") + vmfb_path = Path(f"norm.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + fvn, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( 
+ bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="norm") + shark_module.load_module(vmfb_path) + compiled_module = VicunaNormCompiled(shark_module) + return compiled_module + + def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): + # compile the embedding layer of the vicuna model + # This can be used for both first and second vicuna, so only needs to be run once + mlir_path = Path(f"embedding.mlir") + vmfb_path = Path(f"embedding.vmfb") + if mlir_path.exists(): + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + input_ids = torch_mlir.TensorPlaceholder.like( + input_ids, dynamic_axes=[1] + ) + module = torch_mlir.compile( + fve, + (input_ids,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + bytecode, + device=device, + mlir_dialect="tm_tensor", + device_idx=device_idx, + ) + if vmfb_path.exists(): + shark_module.load_module(vmfb_path) + else: + shark_module.save_module(module_name="embedding") + shark_module.load_module(vmfb_path) + compiled_module = VicunaEmbeddingCompiled(shark_module) + + return compiled_module + + def compile_to_vmfb(self, inputs, layers, device="cpu", is_first=True): + # compile all layers for vmfb + # this needs to be run seperatley for first and second vicuna + mlirs, modules = [], [] + for idx, layer in tqdm(enumerate(layers), desc="Getting mlirs"): + if is_first: + mlir_path = Path(f"{idx}_0.mlir") + vmfb_path = Path(f"{idx}_0.vmfb") + else: + mlir_path = Path(f"{idx}_1.mlir") + vmfb_path = Path(f"{idx}_1.vmfb") + if vmfb_path.exists(): + continue + if mlir_path.exists(): + # print(f"Found layer {idx} mlir") + f_ = open(mlir_path, "rb") + bytecode = f_.read() + f_.close() + else: + hidden_states_placeholder = TensorPlaceholder.like( + inputs[0], dynamic_axes=[1] + ) + attention_mask_placeholder = TensorPlaceholder.like( + inputs[1], dynamic_axes=[3] + ) + position_ids_placeholder = TensorPlaceholder.like( + inputs[2], dynamic_axes=[1] + ) + if not is_first: + pkv0_placeholder = TensorPlaceholder.like( + inputs[3], dynamic_axes=[2] + ) + pkv1_placeholder = TensorPlaceholder.like( + inputs[4], dynamic_axes=[2] + ) + print(f"Compiling layer {idx} mlir") + if is_first: + ts_g = self.compile_vicuna_layer( + layer, inputs[0], inputs[1], inputs[2] + ) + if self.precision in ["int4", "int8"]: + module = torch_mlir.compile( + ts_g, + ( + hidden_states_placeholder, + inputs[1], + inputs[2], + ), + output_type="torch", + backend_legal_ops=[ + "brevitas.matmul_rhs_group_quant" + ], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module = torch_mlir.compile( + ts_g, + ( + hidden_states_placeholder, + inputs[1], + inputs[2], + ), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + else: + ts_g = self.compile_vicuna_layer( + layer, + inputs[0], + 
inputs[1], + inputs[2], + inputs[3], + inputs[4], + ) + if self.precision in ["int4", "int8"]: + module = torch_mlir.compile( + ts_g, + ( + inputs[0], + attention_mask_placeholder, + inputs[2], + pkv0_placeholder, + pkv1_placeholder, + ), + output_type="torch", + backend_legal_ops=[ + "brevitas.matmul_rhs_group_quant" + ], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module = torch_mlir.compile( + ts_g, + ( + inputs[0], + attention_mask_placeholder, + inputs[2], + pkv0_placeholder, + pkv1_placeholder, + ), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + + if is_first: + module = self.write_in_dynamic_inputs0(str(module), 137) + bytecode = module.encode("UTF-8") + bytecode_stream = BytesIO(bytecode) + bytecode = bytecode_stream.read() + else: + module = self.write_in_dynamic_inputs1(str(module), 138) + bytecode = module.encode("UTF-8") + bytecode_stream = BytesIO(bytecode) + bytecode = bytecode_stream.read() + + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + mlirs.append(bytecode) + + for idx, layer in tqdm(enumerate(layers), desc="compiling modules"): + if is_first: + vmfb_path = Path(f"{idx}_0.vmfb") + if vmfb_path.exists(): + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + None, + device=device, + device_idx=device_idx, + mlir_dialect="tm_tensor", + ) + module.load_module(vmfb_path) + else: + print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"first_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + mlirs[idx], + device=device, + device_idx=device_idx, + mlir_dialect="tm_tensor", + ) + module.save_module( + module_name=f"{idx}_0", + extra_args=[ + "--iree-hal-dump-executable-sources-to=ies", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + module.load_module(vmfb_path) + modules.append(module) + else: + vmfb_path = Path(f"{idx}_1.vmfb") + if vmfb_path.exists(): + # print(f"Found layer {idx} vmfb") + device_idx = self.get_device_index( + f"second_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + None, + device=device, + device_idx=device_idx, + mlir_dialect="tm_tensor", + ) + module.load_module(vmfb_path) + else: + print(f"Compiling layer {idx} vmfb") + device_idx = self.get_device_index( + f"second_vicuna.model.model.layers.{idx}[\s.$]" + ) + module = SharkInference( + mlirs[idx], + device=device, + device_idx=device_idx, + mlir_dialect="tm_tensor", + ) + module.save_module( + module_name=f"{idx}_1", + extra_args=[ + "--iree-hal-dump-executable-sources-to=ies", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + module.load_module(vmfb_path) + modules.append(module) + + return mlirs, modules + + def get_sharded_model(self, device="cpu"): + # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess + # please don't change it + 
SAMPLE_INPUT_LEN = 137 + vicuna_model = self.get_src_model() + + if self.precision in ["int4", "int8"]: + print("Applying weight quantization..") + weight_bit_width = 4 if self.precision == "int4" else 8 + quantize_model( + get_model_impl(vicuna_model).layers, + dtype=torch.float32, + weight_quant_type="asym", + weight_bit_width=weight_bit_width, + weight_param_method="stats", + weight_scale_type="float", + weight_quant_granularity="per_group", + weight_group_size=self.weight_group_size, + quantize_weight_zero_point=False, + input_bit_width=None, + input_scale_type="float", + input_param_method="stats", + input_quant_type="asym", + input_quant_granularity="per_tensor", + quantize_input_zero_point=False, + seqlen=2048, + ) + print("Weight quantization applied.") + + placeholder_input0 = ( + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]), + torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64), + ) + + placeholder_input1 = ( + torch.zeros([1, 1, 4096]), + torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]), + torch.zeros([1, 1], dtype=torch.int64), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + torch.zeros([1, 32, SAMPLE_INPUT_LEN, 128]), + ) + + norm = VicunaNorm(vicuna_model.model.norm) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.norm(?:\.|\s|$)" + ) + print(device_idx) + norm = self.compile_norm( + norm, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + + embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens) + device_idx = self.get_device_index( + r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)" + ) + print(device_idx) + embeddings = self.compile_embedding( + embeddings, + (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)), + device=self.device, + device_idx=device_idx, + ) + + lmhead = LMHead(vicuna_model.lm_head) + device_idx = self.get_device_index( + r"vicuna\.model\.lm_head(?:\.|\s|$)" + ) + print(device_idx) + lmhead = self.compile_lmhead( + lmhead, + torch.zeros([1, SAMPLE_INPUT_LEN, 4096]), + device=self.device, + device_idx=device_idx, + ) + + layers0 = [ + FirstVicunaLayer(layer) for layer in vicuna_model.model.layers + ] + _, modules0 = self.compile_to_vmfb( + placeholder_input0, + layers0, + is_first=True, + device=device, + ) + shark_layers0 = [CompiledFirstVicunaLayer(m) for m in modules0] + + layers1 = [ + SecondVicunaLayer(layer) for layer in vicuna_model.model.layers + ] + _, modules1 = self.compile_to_vmfb( + placeholder_input1, layers1, is_first=False, device=device + ) + shark_layers1 = [CompiledSecondVicunaLayer(m) for m in modules1] + + sharded_model = ShardedVicunaModel( + vicuna_model, + shark_layers0, + shark_layers1, + lmhead, + embeddings, + norm, + ) + return sharded_model + + def compile(self, device="cpu"): + return self.get_sharded_model(device=device) + + def generate(self, prompt, cli=False): + # TODO: refactor for cleaner integration + + tokens_generated = [] + _past_key_values = None + _token = None + detoks_generated = [] + for iteration in range(self.max_num_tokens): + params = { + "prompt": prompt, + "is_first": iteration == 0, + "token": _token, + "past_key_values": _past_key_values, + } + + generated_token_op = self.generate_new_token(params=params) + + _token = generated_token_op["token"] + _past_key_values = generated_token_op["past_key_values"] + _detok = generated_token_op["detok"] + + if _token == 2: + break + detoks_generated.append(_detok) + tokens_generated.append(_token) + + for i in range(len(tokens_generated)): + 
if type(tokens_generated[i]) != int: + tokens_generated[i] = int(tokens_generated[i][0]) + result_output = self.tokenizer.decode(tokens_generated) + return result_output + + def generate_new_token(self, params): + is_first = params["is_first"] + if is_first: + prompt = params["prompt"] + input_ids = self.tokenizer(prompt).input_ids + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + output = self.shark_model.forward(input_ids, is_first=is_first) + else: + token = params["token"] + past_key_values = params["past_key_values"] + input_ids = [token] + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + output = self.shark_model.forward( + input_ids, past_key_values=past_key_values, is_first=is_first + ) + + _logits = output["logits"] + _past_key_values = output["past_key_values"] + _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) + _detok = self.tokenizer.decode(_token) + + ret_dict = { + "token": _token, + "detok": _detok, + "past_key_values": _past_key_values, + } + + print(f" token : {_token} | detok : {_detok}") + + return ret_dict + + def autocomplete(self, prompt): + # use First vic alone to complete a story / prompt / sentence. + pass + + +class UnshardedVicuna(SharkLLMBase): + def __init__( + self, + model_name, + hf_model_path="TheBloke/vicuna-7B-1.1-HF", + max_num_tokens=512, + device="cuda", + precision="fp32", + first_vicuna_mlir_path=None, + second_vicuna_mlir_path=None, + first_vicuna_vmfb_path=None, + second_vicuna_vmfb_path=None, + load_mlir_from_shark_tank=True, + low_device_memory=False, + weight_group_size=128, + ) -> None: + super().__init__(model_name, hf_model_path, max_num_tokens) + self.max_sequence_length = 256 + self.device = device + self.precision = precision + self.first_vicuna_vmfb_path = first_vicuna_vmfb_path + self.second_vicuna_vmfb_path = second_vicuna_vmfb_path + self.first_vicuna_mlir_path = first_vicuna_mlir_path + self.second_vicuna_mlir_path = second_vicuna_mlir_path + self.load_mlir_from_shark_tank = load_mlir_from_shark_tank + self.low_device_memory = low_device_memory + self.weight_group_size = weight_group_size + self.first_vic = None + self.second_vic = None + if self.first_vicuna_mlir_path == None: + self.first_vicuna_mlir_path = self.get_model_path() + if self.second_vicuna_mlir_path == None: + self.second_vicuna_mlir_path = self.get_model_path("second") + if self.first_vicuna_vmfb_path == None: + self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb") + if self.second_vicuna_vmfb_path == None: + self.second_vicuna_vmfb_path = self.get_model_path( + "second", "vmfb" + ) + self.tokenizer = self.get_tokenizer() + self.shark_model = self.compile() + + def get_model_path(self, model_number="first", suffix="mlir"): + safe_device = "_".join(self.device.split("-")) + if suffix == "mlir": + return Path(f"{model_number}_vicuna_{self.precision}.{suffix}") + return Path( + f"{model_number}_vicuna_{self.precision}_{safe_device}.{suffix}" + ) + + def get_tokenizer(self): + tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_path, use_fast=False + ) + return tokenizer + + def get_src_model(self): + kwargs = {"torch_dtype": torch.float} + vicuna_model = AutoModelForCausalLM.from_pretrained( + self.hf_model_path, **kwargs + ) + return vicuna_model + + def compile_first_vicuna(self): + vmfb = get_vmfb_from_path( + self.first_vicuna_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + return 
vmfb + + # Compilation path needs some more work before it is functional + print( + f"[DEBUG] vmfb not found at {self.first_vicuna_vmfb_path.absolute()}. Trying to work with\n" + f"[DEBUG] mlir path { self.first_vicuna_mlir_path} {'exists' if self.first_vicuna_mlir_path.exists() else 'does not exist'}" + ) + if self.first_vicuna_mlir_path.exists(): + with open(self.first_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + else: + mlir_generated = False + if self.load_mlir_from_shark_tank: + if self.precision in ["fp32", "fp16", "int8", "int4"]: + # download MLIR from shark_tank + download_public_file( + f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}", + self.first_vicuna_mlir_path.absolute(), + single_file=True, + ) + if self.first_vicuna_mlir_path.exists(): + with open(self.first_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + mlir_generated = True + else: + raise ValueError( + f"MLIR not found at {self.first_vicuna_mlir_path.absolute()}" + " after downloading! Please check path and try again" + ) + else: + print( + f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device." + ) + + if not mlir_generated: + compilation_prompt = "".join(["0" for _ in range(17)]) + compilation_input_ids = self.tokenizer( + compilation_prompt + ).input_ids + compilation_input_ids = torch.tensor( + compilation_input_ids + ).reshape([1, 19]) + firstVicunaCompileInput = (compilation_input_ids,) + model = FirstVicuna( + self.hf_model_path, self.precision, self.weight_group_size + ) + + print(f"[DEBUG] generating torchscript graph") + ts_graph = import_with_fx( + model, + firstVicunaCompileInput, + is_f16=self.precision == "fp16", + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + del model + + firstVicunaCompileInput = list(firstVicunaCompileInput) + firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like( + firstVicunaCompileInput[0], dynamic_axes=[1] + ) + firstVicunaCompileInput = tuple(firstVicunaCompileInput) + + print(f"[DEBUG] generating torch mlir") + if self.precision in ["int4", "int8"]: + module = torch_mlir.compile( + ts_graph, + [*firstVicunaCompileInput], + output_type=torch_mlir.OutputType.TORCH, + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module = torch_mlir.compile( + ts_graph, + [*firstVicunaCompileInput], + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + del ts_graph + + def remove_constant_dim(line): + if "19x" in line: + line = re.sub("19x", "?x", line) + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dim)", line + ) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub( + "tensor.empty\(%dim\)", + "tensor.empty(%dim, %dim)", + line, + ) + if "arith.cmpi" in line: + line = re.sub("c19", "dim", line) + if " 19," in line: + line = re.sub(" 19,", " %dim,", line) + return line + + module = str(module) + new_lines = [] + + print(f"[DEBUG] rewriting torch_mlir file") + for line in module.splitlines(): + line = remove_constant_dim(line) + if "%0 = tensor.empty(%dim) : tensor" in line: + new_lines.append( + "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" + ) + if ( + "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" + in line + ): + continue + + new_lines.append(line) + + module = "\n".join(new_lines) + print(f"[DEBUG] converting to bytecode") + del new_lines + module = module.encode("UTF-8") + module = BytesIO(module) + bytecode = module.read() + del module + + print(f"[DEBUG] writing mlir to file") + f_ = open(self.first_vicuna_mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" + ) + path = shark_module.save_module( + self.first_vicuna_vmfb_path.parent.absolute(), + self.first_vicuna_vmfb_path.stem, + extra_args=[ + "--iree-hal-dump-executable-sources-to=ies", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + print("Saved first vic vmfb at ", str(path)) + shark_module.load_module(path) + + return shark_module + + def compile_second_vicuna(self): + vmfb = get_vmfb_from_path( + self.second_vicuna_vmfb_path, self.device, "tm_tensor" + ) + if vmfb is not None: + return vmfb + + # Compilation path needs some more work before it is functional + print( + f"[DEBUG] mlir path {self.second_vicuna_mlir_path} {'exists' if self.second_vicuna_mlir_path.exists() else 'does not exist'}" + ) + if self.second_vicuna_mlir_path.exists(): + with open(self.second_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + else: + mlir_generated = False + if self.load_mlir_from_shark_tank: + if self.precision in ["fp32", "fp16", "int8", "int4"]: + # download MLIR from shark_tank + download_public_file( + f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}", + self.second_vicuna_mlir_path.absolute(), + single_file=True, + ) + if self.second_vicuna_mlir_path.exists(): + with open(self.second_vicuna_mlir_path, "rb") as f: + bytecode = f.read() + mlir_generated = True + else: + raise ValueError( + f"MLIR not found at {self.second_vicuna_mlir_path.absolute()}" + " after downloading! Please check path and try again" + ) + else: + print( + "Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device." 
+ ) + + if not mlir_generated: + compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) + pkv = tuple( + (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) + for _ in range(64) + ) + secondVicunaCompileInput = (compilation_input_ids,) + pkv + model = SecondVicuna( + self.hf_model_path, self.precision, self.weight_group_size + ) + + print(f"[DEBUG] generating torchscript graph") + ts_graph = import_with_fx( + model, + secondVicunaCompileInput, + is_f16=self.precision == "fp16", + precision=self.precision, + f16_input_mask=[False] + [True] * 64, + mlir_type="torchscript", + ) + if self.precision == "fp16": + secondVicunaCompileInput = get_f16_inputs( + secondVicunaCompileInput, + True, + f16_input_mask=[False] + [True] * 64, + ) + secondVicunaCompileInput = list(secondVicunaCompileInput) + for i in range(len(secondVicunaCompileInput)): + if i != 0: + secondVicunaCompileInput[ + i + ] = torch_mlir.TensorPlaceholder.like( + secondVicunaCompileInput[i], dynamic_axes=[2] + ) + secondVicunaCompileInput = tuple(secondVicunaCompileInput) + + print(f"[DEBUG] generating torch mlir") + if self.precision in ["int4", "int8"]: + module = torch_mlir.compile( + ts_graph, + [*secondVicunaCompileInput], + output_type=torch_mlir.OutputType.TORCH, + backend_legal_ops=["brevitas.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + # TODO: apply --canonicalize to unpack tensor for int4 + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + else: + module = torch_mlir.compile( + ts_graph, + [*secondVicunaCompileInput], + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + + def remove_constant_dim(line): + if "c19_i64" in line: + line = re.sub("c19_i64", "dim_i64", line) + if "19x" in line: + line = re.sub("19x", "?x", line) + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dim)", line + ) + if "tensor.empty" in line and "?x?" 
in line: + line = re.sub( + "tensor.empty\(%dim\)", + "tensor.empty(%dim, %dim)", + line, + ) + if "arith.cmpi" in line: + line = re.sub("c19", "dim", line) + if " 19," in line: + line = re.sub(" 19,", " %dim,", line) + if "20x" in line: + line = re.sub("20x", "?x", line) + line = re.sub( + "tensor.empty\(\)", "tensor.empty(%dimp1)", line + ) + if " 20," in line: + line = re.sub(" 20,", " %dimp1,", line) + return line + + module_str = str(module) + new_lines = [] + + print(f"[DEBUG] rewriting torch_mlir file") + for line in module_str.splitlines(): + if "%c19_i64 = arith.constant 19 : i64" in line: + new_lines.append("%c2 = arith.constant 2 : index") + new_lines.append( + f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>" + ) + new_lines.append( + "%dim_i64 = arith.index_cast %dim_4_int : index to i64" + ) + continue + if "%c2 = arith.constant 2 : index" in line: + continue + if "%c20_i64 = arith.constant 20 : i64" in line: + new_lines.append("%c1_i64 = arith.constant 1 : i64") + new_lines.append( + "%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" + ) + new_lines.append( + "%dimp1 = arith.index_cast %c20_i64 : i64 to index" + ) + continue + line = remove_constant_dim(line) + new_lines.append(line) + + module_str = "\n".join(new_lines) + print(f"[DEBUG] converting to bytecode") + bytecode = module_str.encode("UTF-8") + bytecode_stream = BytesIO(bytecode) + bytecode = bytecode_stream.read() + + print(f"[DEBUG] writing mlir to file") + f_ = open(self.second_vicuna_mlir_path, "wb") + f_.write(bytecode) + f_.close() + + shark_module = SharkInference( + mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor" + ) + + path = shark_module.save_module( + self.second_vicuna_vmfb_path.parent.absolute(), + self.second_vicuna_vmfb_path.stem, + extra_args=[ + "--iree-hal-dump-executable-sources-to=ies", + "--iree-vm-target-truncate-unsupported-floats", + "--iree-codegen-check-ir-before-llvm-conversion=false", + "--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ], + ) + print("Saved second vic vmfb at ", str(path)) + shark_module.load_module(self.second_vicuna_vmfb_path) + + # self.shark_module = shark_module + return shark_module + + def compile(self): + # Cannot load both the models in the memory at once + # due to memory constraints, hence on demand compilation + # is being used until the space is enough for both models + + # Testing : DO NOT Download Vmfbs if not found. 
Modify later + # download vmfbs for A100 + if ( + not self.first_vicuna_vmfb_path.exists() + and self.device in ["cuda", "cpu"] + and self.precision in ["fp32", "fp16"] + ): + # combinations that are still in the works + if not (self.device == "cuda" and self.precision == "fp16"): + # Will generate vmfb on device + pass + else: + download_public_file( + f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}", + self.first_vicuna_vmfb_path.absolute(), + single_file=True, + ) + else: + # get first vic + # TODO: Remove after testing to avoid memory overload + # fvic_shark_model = self.compile_first_vicuna() + pass + if ( + not self.second_vicuna_vmfb_path.exists() + and self.device in ["cuda", "cpu"] + and self.precision in ["fp32", "fp16"] + ): + # combinations that are still in the works + if not (self.device == "cuda" and self.precision == "fp16"): + # Will generate vmfb on device + pass + else: + download_public_file( + f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}", + self.second_vicuna_vmfb_path.absolute(), + single_file=True, + ) + else: + # get second vic + # TODO: Remove after testing to avoid memory overload + # svic_shark_model = self.compile_second_vicuna() + pass + + return None + # return tuple of shark_modules once mem is supported + # return fvic_shark_model, svic_shark_model + + def decode_tokens(self, res_tokens): + for i in range(len(res_tokens)): + if type(res_tokens[i]) != int: + res_tokens[i] = int(res_tokens[i][0]) + + res_str = self.tokenizer.decode(res_tokens) + return res_str + + def generate(self, prompt, cli=False): + # TODO: refactor for cleaner integration + import gc + + if not self.low_device_memory: + if self.first_vic == None: + self.first_vic = self.compile_first_vicuna() + if self.second_vic == None: + self.second_vic = self.compile_second_vicuna() + res_tokens = [] + params = { + "prompt": prompt, + "is_first": True, + "fv": self.compile_first_vicuna() + if self.first_vic == None + else self.first_vic, + } + + generated_token_op = self.generate_new_token(params=params) + + token = generated_token_op["token"] + logits = generated_token_op["logits"] + pkv = generated_token_op["pkv"] + detok = generated_token_op["detok"] + yield detok + + res_tokens.append(token) + if cli: + print(f"Assistant: {detok}", end=" ", flush=True) + + # Clear First Vic from Memory (main and cuda) + if self.low_device_memory: + del params + torch.cuda.empty_cache() + gc.collect() + + for _ in range(self.max_num_tokens - 2): + params = { + "prompt": None, + "is_first": False, + "logits": logits, + "pkv": pkv, + "sv": self.compile_second_vicuna() + if self.second_vic == None + else self.second_vic, + } + + generated_token_op = self.generate_new_token(params=params) + + token = generated_token_op["token"] + logits = generated_token_op["logits"] + pkv = generated_token_op["pkv"] + detok = generated_token_op["detok"] + + if token == 2: + break + res_tokens.append(token) + if detok == "<0x0A>": + if cli: + print("\n", end="", flush=True) + else: + if cli: + print(f"{detok}", end=" ", flush=True) + + if len(res_tokens) % 3 == 0: + part_str = self.decode_tokens(res_tokens) + yield part_str + + if self.device == "cuda": + del sec_vic, pkv, logits + torch.cuda.empty_cache() + gc.collect() + + res_str = self.decode_tokens(res_tokens) + # print(f"[DEBUG] final output : \n{res_str}") + yield res_str + + def generate_new_token(self, params, debug=False): + def forward_first(first_vic, prompt, cache_outputs=False): + input_ids = 
self.tokenizer(prompt).input_ids + input_id_len = len(input_ids) + input_ids = torch.tensor(input_ids) + input_ids = input_ids.reshape([1, input_id_len]) + firstVicunaInput = (input_ids,) + assert first_vic is not None + output_first_vicuna = first_vic("forward", firstVicunaInput) + output_first_vicuna_tensor = torch.tensor(output_first_vicuna[1:]) + logits_first_vicuna = torch.tensor(output_first_vicuna[0]) + if cache_outputs: + torch.save( + logits_first_vicuna, "logits_first_vicuna_tensor.pt" + ) + torch.save( + output_first_vicuna_tensor, "output_first_vicuna_tensor.pt" + ) + token = torch.argmax( + torch.tensor(logits_first_vicuna)[:, -1, :], dim=1 + ) + return token, logits_first_vicuna, output_first_vicuna_tensor + + def forward_second(sec_vic, inputs=None, load_inputs=False): + if inputs is not None: + logits = inputs[0] + pkv = inputs[1:] + elif load_inputs: + pkv = torch.load("output_first_vicuna_tensor.pt") + pkv = tuple(torch.tensor(x) for x in pkv) + logits = torch.load("logits_first_vicuna_tensor.pt") + else: + print( + "Either inputs must be given, or load_inputs must be true" + ) + return None + token = torch.argmax(torch.tensor(logits)[:, -1, :], dim=1) + token = token.to(torch.int64).reshape([1, 1]) + secondVicunaInput = (token,) + tuple(pkv) + + secondVicunaOutput = sec_vic("forward", secondVicunaInput) + new_pkv = secondVicunaOutput[1:] + new_logits = secondVicunaOutput[0] + new_token = torch.argmax(torch.tensor(new_logits)[:, -1, :], dim=1) + return new_token, new_logits, new_pkv + + is_first = params["is_first"] + + if is_first: + prompt = params["prompt"] + fv = params["fv"] + token, logits, pkv = forward_first( + fv, # self.shark_model[0], + prompt=prompt, + cache_outputs=False, + ) + else: + _logits = params["logits"] + _pkv = params["pkv"] + inputs = (_logits,) + tuple(_pkv) + sv = params["sv"] + token, logits, pkv = forward_second( + sv, # self.shark_model[1], + inputs=inputs, + load_inputs=False, + ) + + detok = self.tokenizer.decode(token) + if debug: + print( + f"[DEBUG] is_first: {is_first} |" + f" token : {token} | detok : {detok}" + ) + ret_dict = { + "token": token, + "logits": logits, + "pkv": pkv, + "detok": detok, + } + return ret_dict + + def autocomplete(self, prompt): + # use First vic alone to complete a story / prompt / sentence. + pass + if __name__ == "__main__": args, unknown = parser.parse_known_args() @@ -91,7 +1451,7 @@ else Path(args.second_vicuna_vmfb_path) ) - vic = vp.Vicuna( + vic = UnshardedVicuna( "vicuna", device=args.device, precision=args.precision, @@ -100,6 +1460,7 @@ first_vicuna_vmfb_path=first_vic_vmfb_path, second_vicuna_vmfb_path=second_vic_vmfb_path, load_mlir_from_shark_tank=args.load_mlir_from_shark_tank, + weight_group_size=args.weight_group_size, ) else: if args.config is not None: @@ -108,11 +1469,12 @@ config_file.close() else: config_json = None - vic = vsp.Vicuna( + vic = ShardedVicuna( "vicuna", device=args.device, precision=args.precision, config_json=config_json, + weight_group_size=args.weight_group_size, ) prompt_history = "A chat between a curious user and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the user's questions.\n" prologue_prompt = "ASSISTANT:\n" diff --git a/apps/language_models/src/model_wrappers/vicuna_model.py b/apps/language_models/src/model_wrappers/vicuna_model.py index 1753c13e77..3f84aeb2ae 100644 --- a/apps/language_models/src/model_wrappers/vicuna_model.py +++ b/apps/language_models/src/model_wrappers/vicuna_model.py @@ -1,14 +1,39 @@ import torch from transformers import AutoModelForCausalLM +from brevitas_examples.llm.llm_quant.quantize import quantize_model +from brevitas_examples.llm.llm_quant.run_utils import get_model_impl + class FirstVicuna(torch.nn.Module): - def __init__(self, model_path): + def __init__(self, model_path, precision="fp32", weight_group_size=128): super().__init__() kwargs = {"torch_dtype": torch.float32} self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + if precision in ["int4", "int8"]: + print("First Vicuna applying weight quantization..") + weight_bit_width = 4 if precision == "int4" else 8 + quantize_model( + get_model_impl(self.model).layers, + dtype=torch.float32, + weight_quant_type="asym", + weight_bit_width=weight_bit_width, + weight_param_method="stats", + weight_scale_type="float", + weight_quant_granularity="per_group", + weight_group_size=weight_group_size, + quantize_weight_zero_point=False, + input_bit_width=None, + input_scale_type="float", + input_param_method="stats", + input_quant_type="asym", + input_quant_granularity="per_tensor", + quantize_input_zero_point=False, + seqlen=2048, + ) + print("Weight quantization applied.") def forward(self, input_ids): op = self.model(input_ids=input_ids, use_cache=True) @@ -22,12 +47,34 @@ def forward(self, input_ids): class SecondVicuna(torch.nn.Module): - def __init__(self, model_path): + def __init__(self, model_path, precision="fp32", weight_group_size=128): super().__init__() kwargs = {"torch_dtype": torch.float32} self.model = AutoModelForCausalLM.from_pretrained( model_path, low_cpu_mem_usage=True, **kwargs ) + if precision in ["int4", "int8"]: + print("Second Vicuna applying weight quantization..") + weight_bit_width = 4 if precision == "int4" else 8 + quantize_model( + get_model_impl(self.model).layers, + dtype=torch.float32, + weight_quant_type="asym", + weight_bit_width=weight_bit_width, + weight_param_method="stats", + weight_scale_type="float", + weight_quant_granularity="per_group", + weight_group_size=weight_group_size, + quantize_weight_zero_point=False, + input_bit_width=None, + input_scale_type="float", + input_param_method="stats", + input_quant_type="asym", + input_quant_granularity="per_tensor", + quantize_input_zero_point=False, + seqlen=2048, + ) + print("Weight quantization applied.") def forward( self, diff --git a/setup_venv.sh b/setup_venv.sh index e80560607b..c74be6b4d8 100755 --- a/setup_venv.sh +++ b/setup_venv.sh @@ -159,3 +159,6 @@ if [[ -z "${CONDA_PREFIX}" && "$SKIP_VENV" != "1" ]]; then echo "${Green}Before running examples activate venv with:" echo " ${Green}source $VENV_DIR/bin/activate" fi + +$PYTHON -m pip install brevitas +export PYTHONPATH=`pwd`/third_party/brevitas/src:$PYTHONPATH diff --git a/shark/shark_importer.py b/shark/shark_importer.py index e12f7c0922..902f946961 100644 --- a/shark/shark_importer.py +++ b/shark/shark_importer.py @@ -491,6 +491,7 @@ def import_with_fx( model, inputs, is_f16=False, + precision="fp32", f16_input_mask=None, debug=False, training=False, @@ -504,6 +505,24 @@ def import_with_fx( import torch from 
torch.fx.experimental.proxy_tensor import make_fx from torch._decomp import get_decompositions + from typing import List + + from brevitas_examples.llm.llm_quant.export import ( + block_quant_layer_level_manager, + ) + from brevitas_examples.llm.llm_quant.export import ( + brevitas_layer_export_mode, + ) + from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import ( + LinearWeightBlockQuantHandlerFwd, + ) + from brevitas_examples.llm.llm_quant.export import replace_call_fn_target + from brevitas_examples.llm.llm_quant.sharded_mlir_group_export import ( + matmul_rhs_group_quant_placeholder, + ) + from brevitas.backport.fx.experimental.proxy_tensor import ( + make_fx as brevitas_make_fx, + ) golden_values = None if debug: @@ -511,26 +530,124 @@ def import_with_fx( golden_values = model(*inputs) except: golden_values = None + + def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]: + removed_indexes = [] + for node in fx_g.graph.nodes: + if node.op == "output": + assert ( + len(node.args) == 1 + ), "Output node must have a single argument" + node_arg = node.args[0] + if isinstance(node_arg, (list, tuple)): + node_arg = list(node_arg) + node_args_len = len(node_arg) + for i in range(node_args_len): + curr_index = node_args_len - (i + 1) + if node_arg[curr_index] is None: + removed_indexes.append(curr_index) + node_arg.pop(curr_index) + node.args = (tuple(node_arg),) + break + + if len(removed_indexes) > 0: + fx_g.graph.lint() + fx_g.graph.eliminate_dead_code() + fx_g.recompile() + removed_indexes.sort() + return removed_indexes + + def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: + """ + Replace tuple with tuple element in functions that return one-element tuples. + Returns true if an unwrapping took place, and false otherwise. + """ + unwrapped_tuple = False + for node in fx_g.graph.nodes: + if node.op == "output": + assert ( + len(node.args) == 1 + ), "Output node must have a single argument" + node_arg = node.args[0] + if isinstance(node_arg, tuple): + if len(node_arg) == 1: + node.args = (node_arg[0],) + unwrapped_tuple = True + break + + if unwrapped_tuple: + fx_g.graph.lint() + fx_g.recompile() + return unwrapped_tuple + + def transform_fx(fx_g): + for node in fx_g.graph.nodes: + if node.op == "call_function": + if node.target in [torch.ops.aten.empty]: + # aten.empty should be filled with zeros. + with fx_g.graph.inserting_after(node): + new_node = fx_g.graph.call_function( + torch.ops.aten.zero_, args=(node,) + ) + node.append(new_node) + node.replace_all_uses_with(new_node) + new_node.args = (node,) + fx_g.graph.lint() + # TODO: Control the decompositions. 
- fx_g = make_fx( - model, - decomposition_table=get_decompositions( - [ - torch.ops.aten.embedding_dense_backward, - torch.ops.aten.native_layer_norm_backward, - torch.ops.aten.slice_backward, - torch.ops.aten.select_backward, - torch.ops.aten.norm.ScalarOpt_dim, - torch.ops.aten.native_group_norm, - torch.ops.aten.upsample_bilinear2d.vec, - torch.ops.aten.split.Tensor, - torch.ops.aten.split_with_sizes, - torch.ops.aten.native_layer_norm, - torch.ops.aten.masked_fill.Tensor, - torch.ops.aten.masked_fill.Scalar, - ] - ), - )(*inputs) + if precision in ["int4", "int8"]: + export_context_manager = brevitas_layer_export_mode + export_class = block_quant_layer_level_manager( + export_handlers=[LinearWeightBlockQuantHandlerFwd] + ) + with export_context_manager(model, export_class): + fx_g = brevitas_make_fx( + model, + decomposition_table=get_decompositions( + [ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes, + ] + ), + )(*inputs) + + transform_fx(fx_g) + replace_call_fn_target( + fx_g, + src=matmul_rhs_group_quant_placeholder, + target=torch.ops.brevitas.matmul_rhs_group_quant, + ) + + fx_g.recompile() + removed_none_indexes = _remove_nones(fx_g) + was_unwrapped = _unwrap_single_tuple_return(fx_g) + else: + fx_g = make_fx( + model, + decomposition_table=get_decompositions( + [ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes, + torch.ops.aten.native_layer_norm, + torch.ops.aten.masked_fill.Tensor, + torch.ops.aten.masked_fill.Scalar, + ] + ), + )(*inputs) fx_g.graph.set_codegen(torch.fx.graph.CodeGen()) fx_g.recompile()
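
Note on the new int4/int8 path: the brevitas〇matmul_rhs_group_quant〡* functions added in apps/language_models/scripts/vicuna.py form a torch-mlir "extra library" that resolves the shape and dtype of the custom brevitas.matmul_rhs_group_quant op, which is kept legal via backend_legal_ops and only lowered to linalg in a second step. Below is a minimal, hypothetical sanity-check sketch of those two pure-Python rules; the shapes are illustrative assumptions (a [1, 137, 4096] activation against an [11008, 4096] weight, roughly a Vicuna MLP projection), the per-group metadata layout is assumed, and it presumes the helper functions defined above are in scope.

    import torch

    # Illustrative shapes only: batched activations against a 2-D group-quantized weight.
    lhs = [1, 137, 4096]            # [batch, seq_len, hidden]
    rhs = [11008, 4096]             # [out_features, in_features]
    rhs_scale = [11008, 32, 1]      # per-group metadata (assumed layout)
    rhs_zero_point = [11008, 32, 1]

    # Shape rule: 3-D lhs x 2-D rhs -> [batch, seq_len, out_features].
    out_shape = brevitas〇matmul_rhs_group_quant〡shape(
        lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width=4, rhs_group_size=128
    )
    assert out_shape == [1, 137, 11008]

    # Dtype rule: the output inherits the dtype of the float lhs operand
    # (torch-mlir passes rank/dtype pairs; the dtype value is forwarded unchanged).
    out_dtype = brevitas〇matmul_rhs_group_quant〡dtype(
        (3, torch.float32), (2, torch.int8),
        (3, torch.float32), (3, torch.float32),
        rhs_bit_width=4, rhs_group_size=128,
    )
    assert out_dtype == torch.float32

With weights quantized this way, the script is driven by the flags this patch defines, e.g. --precision int4 --weight-group-size 128; the remaining flags follow the existing argparse setup.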