From 9c50edc664a7e2d78fac46ed9526b01c690fad10 Mon Sep 17 00:00:00 2001 From: Eliasj42 <46754803+Eliasj42@users.noreply.github.com> Date: Mon, 4 Dec 2023 09:11:52 -0800 Subject: [PATCH] fixed functionality of sharded vicuna/llama2 (#1982) Co-authored-by: Elias Joseph --- apps/language_models/scripts/vicuna.py | 556 ++++++++++++++++++------- 1 file changed, 398 insertions(+), 158 deletions(-) diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py index 72a36a6741..7280da73c2 100644 --- a/apps/language_models/scripts/vicuna.py +++ b/apps/language_models/scripts/vicuna.py @@ -1,10 +1,8 @@ import argparse -from dataclasses import dataclass import json import re import gc from io import BytesIO -from os import environ from pathlib import Path from statistics import mean, stdev from tqdm import tqdm @@ -12,6 +10,8 @@ import subprocess import sys import time +from dataclasses import dataclass +from os import environ import torch import torch_mlir @@ -108,6 +108,11 @@ default=128, help="Group size for per_group weight quantization. Default: 128.", ) + +parser.add_argument( + "--n_devices", type=int, default=None, help="Number of GPUs to use" +) + parser.add_argument( "--download_vmfb", default=False, @@ -141,15 +146,9 @@ ) parser.add_argument( "--Xiree_compile", - action='append', + action="append", default=[], - help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments." -) -parser.add_argument( - "--enable_tracing", - default=False, - action=argparse.BooleanOptionalAction, - help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token." + help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments.", ) # Microbenchmarking options. @@ -177,6 +176,12 @@ default="", help="Specify the system prompt. This is only used with `--enable_microbenchmark`", ) +parser.add_argument( + "--enable_tracing", + default=False, + action=argparse.BooleanOptionalAction, + help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token." 
+) parser.add_argument( "--user_prompt", type=str, @@ -436,7 +441,9 @@ def generate_new_token(self, params, sharded=True, cli=True): if sharded: output = self.shark_model.forward(input_ids, is_first=is_first) else: - output = self.shark_model("first_vicuna_forward", (input_ids,), send_to_host=False) + output = self.shark_model( + "first_vicuna_forward", (input_ids,), send_to_host=False + ) else: token = params["token"] @@ -491,6 +498,7 @@ def __init__( self, model_name, hf_model_path="TheBloke/vicuna-7B-1.1-HF", + hf_auth_token=None, max_num_tokens=512, device="cuda", precision="fp32", @@ -499,7 +507,9 @@ def __init__( compressed=False, extra_args_cmd=[], debug=False, + n_devices=None, ) -> None: + self.hf_auth_token = hf_auth_token super().__init__( model_name, hf_model_path, @@ -514,14 +524,17 @@ def __init__( self.config = config_json self.weight_group_size = weight_group_size self.compressed = compressed + self.n_devices = n_devices + self.dir_name = f"{model_name}-{precision}-{device}-models" + self.dir_path = Path(self.dir_name) + if not self.dir_path.is_dir(): + self.dir_path.mkdir(parents=True, exist_ok=True) self.shark_model = self.compile(device=device) def get_tokenizer(self): kwargs = {} - if self.model_name == "llama2": - kwargs = { - "use_auth_token": "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" - } + if "llama2" in self.model_name: + kwargs = {"use_auth_token": self.hf_auth_token} tokenizer = AutoTokenizer.from_pretrained( self.hf_model_path, use_fast=False, @@ -532,8 +545,8 @@ def get_tokenizer(self): def get_src_model(self): # Retrieve the torch model from Huggingface kwargs = {"torch_dtype": torch.float} - if self.model_name == "llama2": - kwargs["use_auth_token"] = "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" + if "llama2" in self.model_name: + kwargs["use_auth_token"] = self.hf_auth_token vicuna_model = AutoModelForCausalLM.from_pretrained( self.hf_model_path, **kwargs, @@ -560,13 +573,17 @@ def write_in_dynamic_inputs0(self, module, dynamic_input_size): return new_module def write_in_dynamic_inputs1(self, module, dynamic_input_size): + if self.precision == "fp32": + fprecision = "32" + else: + fprecision = "16" new_lines = [] for line in module.splitlines(): if "dim_42 =" in line: continue if f"%c{dynamic_input_size}_i64 =" in line: new_lines.append( - "%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf32>" + f"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf{fprecision}>" ) new_lines.append( f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64" @@ -611,9 +628,11 @@ def compile_vicuna_layer( past_key_value0, past_key_value1, ) + is_f16 = self.precision in ["fp16", "int4"] mlir_bytecode = import_with_fx( vicuna_layer, model_inputs, + is_f16=is_f16, precision=self.precision, f16_input_mask=[False, False], mlir_type="torchscript", @@ -664,9 +683,11 @@ def compile_vicuna_layer4( pkv70, pkv71, ) + is_f16 = self.precision in ["fp16", "int4"] mlir_bytecode = import_with_fx( vicuna_layer, model_inputs, + is_f16=is_f16, precision=self.precision, f16_input_mask=[False, False], mlir_type="torchscript", @@ -691,38 +712,82 @@ def get_device_index(self, layer_string): return device_idx def compile_lmhead( - self, lmh, hidden_states, device="cpu", device_idx=None, + self, + lmh, + hidden_states, + device="cpu", + device_idx=None, ): # compile the lm head of the vicuna model # This can be used for both first and second vicuna, so only needs to be run once - mlir_path = Path(f"lmhead.mlir") - vmfb_path = Path(f"lmhead.vmfb") + mlir_path = Path(f"{self.dir_name}/lmhead.mlir") + vmfb_path = 
Path(f"{self.dir_name}/lmhead.vmfb") if mlir_path.exists(): print(f"Found bytecode module at {mlir_path}.") else: - hidden_states = torch_mlir.TensorPlaceholder.like( - hidden_states, dynamic_axes=[1] - ) + # hidden_states = torch_mlir.TensorPlaceholder.like( + # hidden_states, dynamic_axes=[1] + # ) + + is_f16 = self.precision in ["fp16", "int4"] + if is_f16: + ts_graph = import_with_fx( + lmh, + (hidden_states,), + is_f16=is_f16, + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + + if is_f16: + hidden_states = hidden_states.to(torch.float16) + + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + ts_graph, + (hidden_states,), + output_type="torch", + backend_legal_ops=["quant.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) - # module = torch_mlir.compile( - # lmh, - # (hidden_states,), - # torch_mlir.OutputType.LINALG_ON_TENSORS, - # use_tracing=False, - # verbose=False, + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", + ) + + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + module = torch_mlir.compile( + lmh, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + filepath = Path(f"{self.dir_name}/lmhead.mlir") + # download_public_file( + # "gs://shark_tank/elias/compressed_sv/lmhead.mlir", + # filepath.absolute(), + # single_file=True, # ) - # bytecode_stream = BytesIO() - # module.operation.write_bytecode(bytecode_stream) - # bytecode = bytecode_stream.getvalue() - # f_ = open(mlir_path, "wb") - # f_.write(bytecode) - # f_.close() - filepath = Path("lmhead.mlir") - download_public_file( - "gs://shark_tank/elias/compressed_sv/lmhead.mlir", - filepath.absolute(), - single_file=True, - ) mlir_path = filepath shark_module = SharkInference( @@ -735,7 +800,9 @@ def compile_lmhead( if vmfb_path.exists(): shark_module.load_module(vmfb_path) else: - shark_module.save_module(module_name="lmhead", debug=self.debug) + shark_module.save_module( + module_name=f"{self.dir_name}/lmhead", debug=self.debug + ) shark_module.load_module(vmfb_path) compiled_module = LMHeadCompiled(shark_module) return compiled_module @@ -743,28 +810,72 @@ def compile_lmhead( def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): # compile the normalization layer of the vicuna model # This can be used for both first and second vicuna, so only needs to be run once - mlir_path = Path(f"norm.mlir") - vmfb_path = Path(f"norm.vmfb") + mlir_path = Path(f"{self.dir_name}/norm.mlir") + vmfb_path = Path(f"{self.dir_name}/norm.vmfb") if mlir_path.exists(): print(f"Found bytecode module at {mlir_path}.") else: - hidden_states = torch_mlir.TensorPlaceholder.like( - hidden_states, dynamic_axes=[1] - ) - - # module = torch_mlir.compile( - # fvn, - # (hidden_states,), - # torch_mlir.OutputType.LINALG_ON_TENSORS, - # use_tracing=False, - # verbose=False, + # hidden_states = 
torch_mlir.TensorPlaceholder.like( + # hidden_states, dynamic_axes=[1] # ) - filepath = Path("norm.mlir") - download_public_file( - "gs://shark_tank/elias/compressed_sv/norm.mlir", - filepath.absolute(), - single_file=True, + + is_f16 = self.precision in ["fp16", "int4"] + if is_f16: + ts_graph = import_with_fx( + fvn, + (hidden_states,), + is_f16=is_f16, + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + + if is_f16: + hidden_states = hidden_states.to(torch.float16) + + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + + module = torch_mlir.compile( + ts_graph, + (hidden_states,), + output_type="torch", + backend_legal_ops=["quant.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + else: + hidden_states = torch_mlir.TensorPlaceholder.like( + hidden_states, dynamic_axes=[1] + ) + module = torch_mlir.compile( + fvn, + (hidden_states,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ) + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + filepath = Path(f"{self.dir_name}/norm.mlir") + # download_public_file( + # "gs://shark_tank/elias/compressed_sv/norm.mlir", + # filepath.absolute(), + # single_file=True, + # ) mlir_path = filepath shark_module = SharkInference( @@ -777,7 +888,9 @@ def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): if vmfb_path.exists(): shark_module.load_module(vmfb_path) else: - shark_module.save_module(module_name="norm", debug=self.debug) + shark_module.save_module( + module_name=f"{self.dir_name}/norm", debug=self.debug + ) shark_module.load_module(vmfb_path) compiled_module = VicunaNormCompiled(shark_module) return compiled_module @@ -785,33 +898,65 @@ def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): # compile the embedding layer of the vicuna model # This can be used for both first and second vicuna, so only needs to be run once - mlir_path = Path(f"embedding.mlir") - vmfb_path = Path(f"embedding.vmfb") + mlir_path = Path(f"{self.dir_name}/embedding.mlir") + vmfb_path = Path(f"{self.dir_name}/embedding.vmfb") if mlir_path.exists(): print(f"Found bytecode module at {mlir_path}.") else: - input_ids = torch_mlir.TensorPlaceholder.like( - input_ids, dynamic_axes=[1] + is_f16 = self.precision in ["fp16", "int4"] + if is_f16: + # input_ids = torch_mlir.TensorPlaceholder.like( + # input_ids, dynamic_axes=[1] + # ) + ts_graph = import_with_fx( + fve, + (input_ids,), + is_f16=is_f16, + precision=self.precision, + f16_input_mask=[False, False], + mlir_type="torchscript", + ) + input_ids = torch_mlir.TensorPlaceholder.like( + input_ids, dynamic_axes=[1] + ) + module = torch_mlir.compile( + ts_graph, + (input_ids,), + output_type="torch", + backend_legal_ops=["quant.matmul_rhs_group_quant"], + extra_library=brevitas_matmul_rhs_group_quant_library, + use_tracing=False, + verbose=False, + ) + else: + input_ids = 
torch_mlir.TensorPlaceholder.like( + input_ids, dynamic_axes=[1] + ) + module = torch_mlir.compile( + fve, + (input_ids,), + torch_mlir.OutputType.LINALG_ON_TENSORS, + use_tracing=False, + verbose=False, + ) + print(f"[DEBUG] converting torch to linalg") + run_pipeline_with_repro_report( + module, + "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", + description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", ) - # module = torch_mlir.compile( - # fve, - # (input_ids,), - # torch_mlir.OutputType.LINALG_ON_TENSORS, - # use_tracing=False, - # verbose=False, + bytecode_stream = BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + f_ = open(mlir_path, "wb") + f_.write(bytecode) + f_.close() + filepath = Path(f"{self.dir_name}/embedding.mlir") + # download_public_file( + # "gs://shark_tank/elias/compressed_sv/embedding.mlir", + # filepath.absolute(), + # single_file=True, # ) - # bytecode_stream = BytesIO() - # module.operation.write_bytecode(bytecode_stream) - # bytecode = bytecode_stream.getvalue() - # f_ = open(mlir_path, "wb") - # f_.write(bytecode) - # f_.close() - filepath = Path("embedding.mlir") - download_public_file( - "gs://shark_tank/elias/compressed_sv/embedding.mlir", - filepath.absolute(), - single_file=True, - ) mlir_path = filepath shark_module = SharkInference( @@ -824,20 +969,36 @@ def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): if vmfb_path.exists(): shark_module.load_module(vmfb_path) else: - shark_module.save_module(module_name="embedding", debug=self.debug) + shark_module.save_module( + module_name=f"{self.dir_name}/embedding", debug=self.debug + ) shark_module.load_module(vmfb_path) compiled_module = VicunaEmbeddingCompiled(shark_module) return compiled_module def compile_to_vmfb_one_model( - self, inputs0, layers0, inputs1, layers1, device="cpu", + self, + inputs0, + layers0, + inputs1, + layers1, + device="cpu", ): + if self.precision != "fp32": + inputs0 = tuple( + inpt.to(torch.float16) if inpt.dtype == torch.float32 else inpt + for inpt in inputs0 + ) + inputs1 = tuple( + inpt.to(torch.float16) if inpt.dtype == torch.float32 else inpt + for inpt in inputs1 + ) mlirs, modules = [], [] assert len(layers0) == len(layers1) for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))): - mlir_path = Path(f"{idx}_full.mlir") - vmfb_path = Path(f"{idx}_full.vmfb") + mlir_path = Path(f"{self.dir_name}/{idx}_full.mlir") + vmfb_path = Path(f"{self.dir_name}/{idx}_full.vmfb") # if vmfb_path.exists(): # continue if mlir_path.exists(): @@ -876,8 +1037,38 @@ def compile_to_vmfb_one_model( layer0, inputs0[0], inputs0[1], inputs0[2] ) if self.precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import quantize_model - from brevitas_examples.llm.llm_quant.run_utils import get_model_impl + from brevitas_examples.common.generative.quantize import ( + quantize_model, + ) + from brevitas_examples.llm.llm_quant.run_utils import ( + get_model_impl, + ) + + hidden_states_placeholder0 = TensorPlaceholder.like( + inputs0[0], dynamic_axes=[1] + ) + attention_mask_placeholder0 = TensorPlaceholder.like( + inputs0[1], dynamic_axes=[3] + ) + position_ids_placeholder0 = TensorPlaceholder.like( + inputs0[2], dynamic_axes=[1] + ) + hidden_states_placeholder1 = TensorPlaceholder.like( + inputs1[0], dynamic_axes=[1] + ) + attention_mask_placeholder1 = TensorPlaceholder.like( + 
inputs1[1], dynamic_axes=[3] + ) + position_ids_placeholder1 = TensorPlaceholder.like( + inputs1[2], dynamic_axes=[1] + ) + pkv0_placeholder = TensorPlaceholder.like( + inputs1[3], dynamic_axes=[2] + ) + pkv1_placeholder = TensorPlaceholder.like( + inputs1[4], dynamic_axes=[2] + ) + module0 = torch_mlir.compile( ts_g, ( @@ -891,6 +1082,7 @@ def compile_to_vmfb_one_model( use_tracing=False, verbose=False, ) + print(f"[DEBUG] converting torch to linalg") run_pipeline_with_repro_report( module0, @@ -958,7 +1150,7 @@ def compile_to_vmfb_one_model( module1 = self.write_in_dynamic_inputs1(str(module1), 138) module_combined = self.combine_mlir_scripts( - module0, module1, f"{idx}_full.mlir" + module0, module1, f"{self.dir_name}/{idx}_full.mlir" ) mlirs.append(module_combined) @@ -966,6 +1158,11 @@ def compile_to_vmfb_one_model( device_idx = self.get_device_index( f"first_vicuna.model.model.layers.{idx}[\s.$]" ) + if device_idx is None: + if self.n_devices is not None: + device_idx = idx % self.n_devices + else: + device_idx = None module = SharkInference( None, device=device, @@ -979,6 +1176,11 @@ def compile_to_vmfb_one_model( device_idx = self.get_device_index( f"first_vicuna.model.model.layers.{idx}[\s.$]" ) + if device_idx is None: + if self.n_devices is not None: + device_idx = idx % self.n_devices + else: + device_idx = 0 module = SharkInference( mlirs[idx], device=device, @@ -987,7 +1189,7 @@ def compile_to_vmfb_one_model( mmap=False, ) module.save_module( - module_name=f"{idx}_full", + module_name=f"{self.dir_name}/{idx}_full", extra_args=[ "--iree-vm-target-truncate-unsupported-floats", "--iree-codegen-check-ir-before-llvm-conversion=false", @@ -1032,10 +1234,15 @@ def compile_to_vmfb_one_model4( device_idx = self.get_device_index( f"first_vicuna.model.model.layers.{idx}[\s.$]" ) + if device_idx is None: + if self.n_devices is not None: + device_idx = idx % self.n_devices + else: + device_idx = 0 module = SharkInference( None, device=device, - device_idx=0, + device_idx=device_idx, mlir_dialect="tm_tensor", mmap=True, ) @@ -1045,10 +1252,15 @@ def compile_to_vmfb_one_model4( device_idx = self.get_device_index( f"first_vicuna.model.model.layers.{idx}[\s.$]" ) + if device_idx is None: + if self.n_devices is not None: + device_idx = idx % self.n_devices + else: + device_idx = 0 module = SharkInference( mlirs[idx], device=device, - device_idx=0, + device_idx=device_idx, mlir_dialect="tm_tensor", mmap=True, ) @@ -1077,8 +1289,13 @@ def get_sharded_model(self, device="cpu", compressed=False): ) if self.precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import quantize_model - from brevitas_examples.llm.llm_quant.run_utils import get_model_impl + from brevitas_examples.common.generative.quantize import ( + quantize_model, + ) + from brevitas_examples.llm.llm_quant.run_utils import ( + get_model_impl, + ) + print("Applying weight quantization..") weight_bit_width = 4 if self.precision == "int4" else 8 quantize_model( @@ -1186,7 +1403,7 @@ def get_sharded_model(self, device="cpu", compressed=False): layers0 = [layers00, layers01, layers02, layers03] layers1 = [layers10, layers11, layers12, layers13] - _, modules = self.compile_to_vmfb_one_model4( + _, modules = self.compile_to_vmfb_one_model( placeholder_input0, layers0, placeholder_input1, @@ -1213,9 +1430,6 @@ def compile(self, device="cpu"): return self.get_sharded_model( device=device, compressed=self.compressed ) - return self.get_sharded_model( - device=device, compressed=self.compressed - ) def 
generate(self, prompt, cli=False): # TODO: refactor for cleaner integration @@ -1234,13 +1448,17 @@ def generate(self, prompt, cli=False): "past_key_values": _past_key_values, } + decode_st_time = time.time() + generated_token_op = self.generate_new_token(params=params) + prefill_time = time.time() - decode_st_time + _token = generated_token_op["token"] _past_key_values = generated_token_op["past_key_values"] _detok = generated_token_op["detok"] history.append(_token) - yield self.tokenizer.decode(history) + yield self.tokenizer.decode(history), None, prefill_time if _token == 2: break @@ -1251,7 +1469,7 @@ def generate(self, prompt, cli=False): if type(tokens_generated[i]) != int: tokens_generated[i] = int(tokens_generated[i][0]) result_output = self.tokenizer.decode(tokens_generated) - yield result_output + yield result_output, "formatted", None def autocomplete(self, prompt): # use First vic alone to complete a story / prompt / sentence. @@ -1308,9 +1526,11 @@ def __init__( # Sanity check for device, device_id pair if "://" in device: if device_id is not None: - print("[ERR] can't have both full device path and a device id.\n" - f"Device : {device} | device_id : {device_id}\n" - "proceeding with given Device ignoring device_id") + print( + "[ERR] can't have both full device path and a device id.\n" + f"Device : {device} | device_id : {device_id}\n" + "proceeding with given Device ignoring device_id" + ) self.device, self.device_id = device.split("://") if len(self.device_id) < 2: self.device_id = int(self.device_id) @@ -1342,13 +1562,19 @@ def get_model_path(self, suffix="mlir"): target_triple = "" if self.vulkan_target_triple != "": target_triple = "_" - target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1]) + target_triple += "_".join( + self.vulkan_target_triple.split("-")[:-1] + ) differentiator = target_triple elif "rocm" == self.device: from shark.iree_utils.gpu_utils import get_rocm_device_arch - device_arch = get_rocm_device_arch(self.device_id if self.device_id is not None else 0, self.extra_args) - differentiator = '_' + device_arch + + device_arch = get_rocm_device_arch( + self.device_id if self.device_id is not None else 0, + self.extra_args, + ) + differentiator = "_" + device_arch return Path( f"{self.model_name}_{self.precision}_{safe_device}{differentiator}.{suffix}" @@ -1502,9 +1728,15 @@ def compile(self): mlir_generated = False for suffix in ["mlirbc", "mlir"]: self.vicuna_mlir_path = self.get_model_path(suffix) - if "cpu" in self.device and "llama2_7b" in self.vicuna_mlir_path.name: + if ( + "cpu" in self.device + and "llama2_7b" in self.vicuna_mlir_path.name + ): self.vicuna_mlir_path = Path("llama2_7b_int4_f32.mlir") - if not self.vicuna_mlir_path.exists() and self.load_mlir_from_shark_tank: + if ( + not self.vicuna_mlir_path.exists() + and self.load_mlir_from_shark_tank + ): print( f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}" ) @@ -1514,7 +1746,9 @@ def compile(self): single_file=True, ) if self.vicuna_mlir_path.exists(): - print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}") + print( + f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}" + ) combined_module = self.vicuna_mlir_path.absolute() mlir_generated = True break @@ -1546,7 +1780,7 @@ def compile(self): model = FirstVicuna( self.hf_model_path, self.precision, - "fp32" if self.device=="cpu" else "fp16", + "fp32" if self.device == "cpu" else "fp16", self.weight_group_size, self.model_name, self.hf_auth_token, @@ -1555,7 
+1789,7 @@ def compile(self): model = FirstVicunaGPU( self.hf_model_path, self.precision, - "fp32" if self.device=="cpu" else "fp16", + "fp32" if self.device == "cpu" else "fp16", self.weight_group_size, self.model_name, self.hf_auth_token, @@ -1572,9 +1806,7 @@ def compile(self): ) del model firstVicunaCompileInput = list(firstVicunaCompileInput) - firstVicunaCompileInput[ - 0 - ] = torch_mlir.TensorPlaceholder.like( + firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like( firstVicunaCompileInput[0], dynamic_axes=[1] ) @@ -1592,7 +1824,9 @@ def compile(self): verbose=False, ) if self.cache_vicunas: - with open(first_model_path[:-5]+"_torch.mlir", "w+") as f: + with open( + first_model_path[:-5] + "_torch.mlir", "w+" + ) as f: f.write(str(first_module)) print(f"[DEBUG] converting torch to linalg") run_pipeline_with_repro_report( @@ -1624,16 +1858,16 @@ def compile(self): print("Finished writing IR after dynamic") print(f"[DEBUG] Starting generation of second llama") - second_model_path = f"second_{self.model_name}_{self.precision}.mlir" + second_model_path = ( + f"second_{self.model_name}_{self.precision}.mlir" + ) if Path(second_model_path).exists(): print(f"loading {second_model_path}") with open(Path(second_model_path), "r") as f: second_module = f.read() else: # generate second vicuna - compilation_input_ids = torch.zeros( - [1, 1], dtype=torch.int64 - ) + compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) if self.model_name == "llama2_13b": dim1 = 40 total_tuple = 80 @@ -1724,7 +1958,9 @@ def compile(self): secondVicunaCompileInput = list(secondVicunaCompileInput) for i in range(len(secondVicunaCompileInput)): if i != 0: - secondVicunaCompileInput[i] = torch_mlir.TensorPlaceholder.like( + secondVicunaCompileInput[ + i + ] = torch_mlir.TensorPlaceholder.like( secondVicunaCompileInput[i], dynamic_axes=[2] ) secondVicunaCompileInput = tuple(secondVicunaCompileInput) @@ -1741,7 +1977,9 @@ def compile(self): ) print(f"[DEBUG] converting torch to linalg") if self.cache_vicunas: - with open(second_model_path[:-5]+"_torch.mlir", "w+") as f: + with open( + second_model_path[:-5] + "_torch.mlir", "w+" + ) as f: f.write(str(second_module)) run_pipeline_with_repro_report( second_module, @@ -1784,13 +2022,15 @@ def compile(self): ) del first_module, second_module - print(f"Compiling for device : {self.device}" - f"{'://' + str(self.device_id) if self.device_id is not None else ''}") + print( + f"Compiling for device : {self.device}" + f"{'://' + str(self.device_id) if self.device_id is not None else ''}" + ) shark_module = SharkInference( mlir_module=combined_module, device=self.device, mlir_dialect="tm_tensor", - device_idx=self.device_id + device_idx=self.device_id, ) path = shark_module.save_module( self.vicuna_vmfb_path.parent.absolute(), @@ -1812,9 +2052,7 @@ def decode_tokens(self, res_tokens): if type(res_tokens[i]) != int: res_tokens[i] = int(res_tokens[i][0]) - res_str = self.tokenizer.decode( - res_tokens, skip_special_tokens=False - ) + res_str = self.tokenizer.decode(res_tokens, skip_special_tokens=False) return res_str def generate(self, prompt, cli): @@ -1855,7 +2093,7 @@ def generate(self, prompt, cli): generated_token_op = self.generate_new_token( params=params, sharded=False, cli=cli ) - decode_time_ms = (time.time() - decode_st_time)*1000 + decode_time_ms = (time.time() - decode_st_time) * 1000 token = generated_token_op["token"] if "cpu" not in self.device: @@ -1944,7 +2182,6 @@ def create_prompt(model_name, history): msg = msg.strip() return msg - def 
miliseconds_to_seconds(ms: float) -> float: return ms / 1000.0 @@ -1994,6 +2231,15 @@ def print(self) -> None: print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s") print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s") print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)") + + def enable_tracy_tracing(): + # Make tracy wait for a caputre to be collected before exiting. + environ["TRACY_NO_EXIT"] = "1" + + if "IREE_PY_RUNTIME" not in environ or environ["IREE_PY_RUNTIME"] != "tracy": + print("ERROR: Tracing enabled but tracy iree runtime not used.", file=sys.stderr) + print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr) + sys.exit(1) def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None: @@ -2026,26 +2272,21 @@ def avg_and_stdev(data): print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)") -def enable_tracy_tracing(): - # Make tracy wait for a caputre to be collected before exiting. - environ["TRACY_NO_EXIT"] = "1" - - if "IREE_PY_RUNTIME" not in environ or environ["IREE_PY_RUNTIME"] != "tracy": - print("ERROR: Tracing enabled but tracy iree runtime not used.", file=sys.stderr) - print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr) - sys.exit(1) - - if __name__ == "__main__": args, unknown = parser.parse_known_args() _extra_args = list(args.Xiree_compile) - device_id = None + model_list = { + "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF", + "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf", + "llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf", + "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf", + } + device_id = None if args.enable_tracing: enable_tracy_tracing() - # Process vulkan target triple. # TODO: This feature should just be in a common utils for other LLMs and in general # any model run via SHARK for Vulkan backend. @@ -2057,8 +2298,9 @@ def enable_tracy_tracing(): # Step 1. Fetch the device ID. from shark.iree_utils.vulkan_utils import ( get_all_vulkan_devices, - get_vulkan_target_triple + get_vulkan_target_triple, ) + vulkaninfo_list = get_all_vulkan_devices() id = 0 for device in vulkaninfo_list: @@ -2068,7 +2310,9 @@ def enable_tracy_tracing(): break id += 1 - assert device_id, f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" + assert ( + device_id + ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" # Step 2. Add a few flags targetting specific hardwares. 
if "rdna" in vulkan_target_triple: flags_to_add = [ @@ -2076,7 +2320,6 @@ def enable_tracy_tracing(): ] _extra_args = _extra_args + flags_to_add - vic = None if not args.sharded: vic_mlir_path = ( @@ -2109,7 +2352,7 @@ def enable_tracy_tracing(): download_vmfb=args.download_vmfb, cache_vicunas=args.cache_vicunas, extra_args_cmd=_extra_args, - device_id=device_id + device_id=device_id, ) else: if args.config is not None: @@ -2118,28 +2361,29 @@ def enable_tracy_tracing(): config_file.close() else: config_json = None + + print( + f"[DEBUG]: model_name_input = {model_list[args.model_name].split('=>')[1]}" + ) vic = ShardedVicuna( model_name=args.model_name, + hf_model_path=model_list[args.model_name].split("=>")[1], + hf_auth_token=args.hf_auth_token, device=args.device, precision=args.precision, config_json=config_json, weight_group_size=args.weight_group_size, extra_args_cmd=_extra_args, + n_devices=args.n_devices, ) history = [] - model_list = { - "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF", - "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf", - "llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf", - "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf", - } - iteration = 0 benchmark_run_infos = [] + while True: # TODO: Add break condition from user input iteration += 1 @@ -2159,11 +2403,9 @@ def enable_tracy_tracing(): prefill_time_ms = 0 is_first = True token_times_ms = [] - for text, msg, exec_time in vic.generate(prompt, cli=True): if args.enable_tracing: vic.shark_model.shark_runner.iree_config.device.flush_profiling() - if msg is None: if is_first: # Note that the prefill time is in seconds, and all the decoded tokens in ms. @@ -2172,12 +2414,10 @@ def enable_tracy_tracing(): else: token_times_ms.append(exec_time) elif "formatted" in msg: - history[-1][1] = text print(f"\nResponse:\n{text.strip()}\n") run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms) run_info.print() benchmark_run_infos.append(run_info) - else: sys.exit( "unexpected message from the vicuna generate call, exiting." @@ -2185,4 +2425,4 @@ def enable_tracy_tracing(): if args.enable_microbenchmark: print("\n### Final Statistics ###") - print_aggregate_stats(benchmark_run_infos) + print_aggregate_stats(benchmark_run_infos) \ No newline at end of file