diff --git a/apps/language_models/src/model_wrappers/minigpt4.py b/apps/language_models/src/model_wrappers/minigpt4.py
index 6521a0a5fe..f0cec5157d 100644
--- a/apps/language_models/src/model_wrappers/minigpt4.py
+++ b/apps/language_models/src/model_wrappers/minigpt4.py
@@ -2,9 +2,12 @@
 import dataclasses
 from enum import auto, Enum
 from typing import List, Any
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import StoppingCriteria
+from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
+
 
 
 class LayerNorm(torch.nn.LayerNorm):
     """Subclass torch's LayerNorm to handle fp16."""
@@ -15,10 +18,38 @@ def forward(self, x: torch.Tensor):
 
 
 class VisionModel(torch.nn.Module):
-    def __init__(self, ln_vision, visual_encoder):
+    def __init__(self, ln_vision, visual_encoder, precision="fp32", weight_group_size=128):
         super().__init__()
         self.ln_vision = ln_vision
         self.visual_encoder = visual_encoder
+        if precision in ["int4", "int8"]:
+            print("Vision Model applying weight quantization to ln_vision")
+            weight_bit_width = 4 if precision == "int4" else 8
+            quantize_model(
+                self.ln_vision,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")
+            print("Vision Model applying weight quantization to visual_encoder")
+            quantize_model(
+                self.visual_encoder,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")
 
     def forward(self, image):
         image_embeds = self.ln_vision(self.visual_encoder(image))
@@ -41,10 +72,25 @@ def forward(self, query_tokens, image_embeds, image_atts):
 
 
 class FirstLlamaModel(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, precision="fp32", weight_group_size=128):
         super().__init__()
         self.model = model
         print("SHARK: Loading LLAMA Done")
+        if precision in ["int4", "int8"]:
+            print("First Llama applying weight quantization")
+            weight_bit_width = 4 if precision == "int4" else 8
+            quantize_model(
+                self.model,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")
 
     def forward(self, inputs_embeds, position_ids, attention_mask):
         print("************************************")
@@ -90,10 +136,25 @@ def forward(self, inputs_embeds, position_ids, attention_mask):
 
 
 class SecondLlamaModel(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, precision="fp32", weight_group_size=128):
         super().__init__()
         self.model = model
         print("SHARK: Loading LLAMA Done")
+        if precision in ["int4", "int8"]:
+            print("Second Llama applying weight quantization")
+            weight_bit_width = 4 if precision == "int4" else 8
+            quantize_model(
+                self.model,
+                dtype=torch.float32,
+                weight_bit_width=weight_bit_width,
+                weight_param_method="stats",
+                weight_scale_precision="float",
+                weight_quant_type="asym",
+                weight_quant_granularity="per_group",
+                weight_group_size=weight_group_size,
+                quantize_weight_zero_point=False,
+            )
+            print("Weight quantization applied.")
 
     def forward(
         self,
diff --git a/apps/language_models/src/pipelines/minigpt4_pipeline.py b/apps/language_models/src/pipelines/minigpt4_pipeline.py
index 13b2e25824..7ad1359fcc 100644
--- a/apps/language_models/src/pipelines/minigpt4_pipeline.py
+++ b/apps/language_models/src/pipelines/minigpt4_pipeline.py
@@ -110,7 +110,6 @@
     help="Maximum no. of new tokens that can be generated for a query",
 )
 
-
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
     does not change anymore."""
@@ -124,7 +123,179 @@ def is_url(input_url):
     is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
     return is_url
 
+import os
+import tempfile
+from shark.shark_inference import SharkInference
+from shark.shark_importer import import_with_fx
+import torch
+import torch_mlir
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from typing import List, Tuple
+from io import BytesIO
+from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
+
+
+def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
+    if len(lhs) == 3 and len(rhs) == 2:
+        return [lhs[0], lhs[1], rhs[0]]
+    elif len(lhs) == 2 and len(rhs) == 2:
+        return [lhs[0], rhs[0]]
+    else:
+        raise ValueError("Input shapes not supported.")
+
+
+def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
+    # output dtype is the dtype of the lhs float input
+    lhs_rank, lhs_dtype = lhs_rank_dtype
+    return lhs_dtype
+
+
+def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+    return
+
+
+brevitas_matmul_rhs_group_quant_library = [
+    brevitas〇matmul_rhs_group_quant〡shape,
+    brevitas〇matmul_rhs_group_quant〡dtype,
+    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
+
+
+def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
+    vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+    shark_module = None
+    if os.path.isfile(vmfb_path):
+        shark_module = SharkInference(
+            None,
+            device=device,
+            mlir_dialect=mlir_dialect,
+        )
+        print(f"loading existing vmfb from: {vmfb_path}")
+        shark_module.load_module(vmfb_path, extra_args=extra_args)
+    return shark_module
+
+
+def compile_module(
+    shark_module, extended_model_name, generate_vmfb, extra_args=[]
+):
+    if generate_vmfb:
+        vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb")
+        if os.path.isfile(vmfb_path):
+            print(f"loading existing vmfb from: {vmfb_path}")
+            shark_module.load_module(vmfb_path, extra_args=extra_args)
+        else:
+            print(
+                "No vmfb found. Compiling and saving to {}".format(vmfb_path)
+            )
+            path = shark_module.save_module(
+                os.getcwd(), extended_model_name, extra_args
+            )
+            shark_module.load_module(path, extra_args=extra_args)
+    else:
+        shark_module.compile(extra_args)
+    return shark_module
+
+
+def compile_int_precision(model, inputs, precision, device, generate_vmfb, extended_model_name):
+    torchscript_module = import_with_fx(
+        model,
+        inputs,
+        precision=precision,
+        mlir_type="torchscript",
+    )
+    mlir_module = torch_mlir.compile(
+        torchscript_module,
+        inputs,
+        output_type="torch",
+        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        extra_library=brevitas_matmul_rhs_group_quant_library,
+        use_tracing=False,
+        verbose=False,
+    )
+    print(f"[DEBUG] converting torch to linalg")
+    run_pipeline_with_repro_report(
+        mlir_module,
+        "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
+        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
+    )
+    from contextlib import redirect_stdout
+
+    mlir_file_path = os.path.join(os.getcwd(), f"{extended_model_name}_linalg.mlir")
+    with open(mlir_file_path, 'w') as f:
+        with redirect_stdout(f):
+            print(mlir_module.operation.get_asm())
+    mlir_module = str(mlir_module)
+    mlir_module = mlir_module.encode("UTF-8")
+    mlir_module = BytesIO(mlir_module)
+    bytecode = mlir_module.read()
+    print(f"Elided IR written for {extended_model_name}")
+    return bytecode
+    shark_module = SharkInference(
+        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
+    )
+    extra_args = [
+        "--iree-hal-dump-executable-sources-to=ies",
+        "--iree-vm-target-truncate-unsupported-floats",
+        "--iree-codegen-check-ir-before-llvm-conversion=false",
+        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
+    ]
+    return (
+        compile_module(shark_module, extended_model_name=extended_model_name, generate_vmfb=generate_vmfb, extra_args=extra_args),
+        bytecode
+    )
+
+
+def shark_compile_through_fx_int(
+    model,
+    inputs,
+    extended_model_name,
+    precision,
+    f16_input_mask=None,
+    save_dir=tempfile.gettempdir(),
+    debug=False,
+    generate_or_load_vmfb=True,
+    extra_args=[],
+    device=None,
+    mlir_dialect="tm_tensor",
+):
+    if generate_or_load_vmfb:
+        shark_module = load_vmfb(
+            extended_model_name=extended_model_name,
+            device=device,
+            mlir_dialect=mlir_dialect,
+            extra_args=extra_args,
+        )
+        if shark_module:
+            return (
+                shark_module,
+                None,
+            )
+
+    from shark.parser import shark_args
+    if "cuda" in device:
+        shark_args.enable_tf32 = True
+
+    mlir_module = compile_int_precision(model, inputs, precision, device, generate_or_load_vmfb, extended_model_name)
+    extra_args = [
+        "--iree-hal-dump-executable-sources-to=ies",
+        "--iree-vm-target-truncate-unsupported-floats",
+        "--iree-codegen-check-ir-before-llvm-conversion=false",
+        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
+    ]
+
+    shark_module = SharkInference(
+        mlir_module,
+        device=device,
+        mlir_dialect=mlir_dialect,
+    )
+    return (
+        compile_module(
+            shark_module,
+            extended_model_name,
+            generate_vmfb=generate_or_load_vmfb,
+            extra_args=extra_args,
+        ),
+        mlir_module,
+    )
+
 
 class MiniGPT4BaseModel(torch.nn.Module):
     @classmethod
     def from_config(cls, cfg):
@@ -454,6 +625,12 @@ def download_dependencies(self):
 
     # Currently we're compiling VisionModel for fp32/cuda.
     def compile_vision_model(self):
+        # TODO: Hardcoding precision based on input choices. Take this down
+        # later.
+        vision_model_precision = "fp32"
+        if self.precision in ["int4", "int8", "fp16"]:
+            vision_model_precision = "fp16"
+
         if not self._compile:
             vmfb = get_vmfb_from_path(
                 self.vision_model_vmfb_path, self.device, "tm_tensor"
@@ -464,7 +641,7 @@ def compile_vision_model(self):
             vmfb = get_vmfb_from_config(
                 self.model_name,
                 "vision_model",
-                self.precision,
+                vision_model_precision,
                 self.device,
                 self.vision_model_vmfb_path,
             )
@@ -472,28 +649,43 @@ def compile_vision_model(self):
             return vmfb
 
        visionModel = VisionModel(
-            self.model.ln_vision, self.model.visual_encoder
+            copy.deepcopy(self.model.ln_vision), copy.deepcopy(self.model.visual_encoder), vision_model_precision
        )
-        extended_model_name = f"vision_model_{self.precision}_{self.device}"
+        extended_model_name = f"vision_model_{vision_model_precision}_{self.device}"
        print(f"Going to compile {extended_model_name}")
        # Inputs for VisionModel.
        inputs = [torch.randint(3, (1, 3, 224, 224), dtype=torch.float32)]
        is_f16 = False
-        if self.precision == "fp16":
+        if vision_model_precision == "fp16":
            is_f16 = True
-        shark_visionModel, _ = shark_compile_through_fx(
-            visionModel,
-            inputs,
-            extended_model_name=extended_model_name,
-            is_f16=is_f16,
-            f16_input_mask=None,
-            save_dir=tempfile.gettempdir(),
-            debug=False,
-            generate_or_load_vmfb=True,
-            extra_args=[],
-            device=self.device,
-            mlir_dialect="tm_tensor",
-        )
+        if self.precision in ["int4", "int8"]:
+            shark_visionModel, _ = shark_compile_through_fx_int(
+                visionModel,
+                inputs,
+                extended_model_name=extended_model_name,
+                precision=vision_model_precision,
+                f16_input_mask=None,
+                save_dir=tempfile.gettempdir(),
+                debug=False,
+                generate_or_load_vmfb=True,
+                extra_args=[],
+                device=self.device,
+                mlir_dialect="tm_tensor",
+            )
+        else:
+            shark_visionModel, _ = shark_compile_through_fx(
+                visionModel,
+                inputs,
+                extended_model_name=extended_model_name,
+                precision=vision_model_precision,
+                f16_input_mask=None,
+                save_dir=tempfile.gettempdir(),
+                debug=False,
+                generate_or_load_vmfb=True,
+                extra_args=[],
+                device=self.device,
+                mlir_dialect="tm_tensor",
+            )
        print(f"Generated {extended_model_name}.vmfb")
        return shark_visionModel
@@ -530,7 +722,7 @@ def compile_qformer_model(self):
             qformerBertModel,
             inputs,
             extended_model_name=extended_model_name,
-            is_f16=is_f16,
+            precision="fp32",
             f16_input_mask=f16_input_mask,
             save_dir=tempfile.gettempdir(),
             debug=False,
@@ -567,7 +759,7 @@ def compile_first_llama(self, padding):
             return vmfb
 
        firstLlamaModel = FirstLlamaModel(
-            copy.deepcopy(self.model.llama_model)
+            copy.deepcopy(self.model.llama_model), self.precision
        )
        extended_model_name = (
            f"first_llama_{self.precision}_{self.device}_{padding}"
@@ -583,19 +775,34 @@ def compile_first_llama(self, padding):
        if self.precision == "fp16":
            is_f16 = True
            f16_input_mask = [True, False, False]
-        shark_firstLlamaModel, _ = shark_compile_through_fx(
-            firstLlamaModel,
-            inputs,
-            extended_model_name=extended_model_name,
-            is_f16=is_f16,
-            f16_input_mask=f16_input_mask,
-            save_dir=tempfile.gettempdir(),
-            debug=False,
-            generate_or_load_vmfb=True,
-            extra_args=[],
-            device=self.device,
-            mlir_dialect="tm_tensor",
-        )
+        if self.precision in ["int4", "int8"]:
+            shark_firstLlamaModel, _ = shark_compile_through_fx_int(
+                firstLlamaModel,
+                inputs,
+                extended_model_name=extended_model_name,
+                precision=self.precision,
+                f16_input_mask=f16_input_mask,
+                save_dir=tempfile.gettempdir(),
+                debug=False,
+                generate_or_load_vmfb=True,
+                extra_args=[],
+                device=self.device,
+                mlir_dialect="tm_tensor",
+            )
+        else:
+            shark_firstLlamaModel, _ = shark_compile_through_fx(
+                firstLlamaModel,
+                inputs,
+                extended_model_name=extended_model_name,
+                precision=self.precision,
+                f16_input_mask=f16_input_mask,
+                save_dir=tempfile.gettempdir(),
+                debug=False,
+                generate_or_load_vmfb=True,
+                extra_args=[],
+                device=self.device,
+                mlir_dialect="tm_tensor",
+            )
        print(f"Generated {extended_model_name}.vmfb")
        self.first_llama = shark_firstLlamaModel
        return shark_firstLlamaModel
@@ -625,7 +832,7 @@ def compile_second_llama(self, padding):
             return vmfb
 
        secondLlamaModel = SecondLlamaModel(
-            copy.deepcopy(self.model.llama_model)
+            copy.deepcopy(self.model.llama_model), self.precision
        )
        extended_model_name = (
            f"second_llama_{self.precision}_{self.device}_{padding}"
@@ -649,19 +856,34 @@ def compile_second_llama(self, padding):
            for i in past_key_value:
                f16_input_mask.append(True)
 
-        shark_secondLlamaModel, _ = shark_compile_through_fx(
-            secondLlamaModel,
-            inputs,
-            extended_model_name=extended_model_name,
-            is_f16=is_f16,
-            f16_input_mask=f16_input_mask,
-            save_dir=tempfile.gettempdir(),
-            debug=False,
-            generate_or_load_vmfb=True,
-            extra_args=[],
-            device=self.device,
-            mlir_dialect="tm_tensor",
-        )
+        if self.precision in ["int4", "int8"]:
+            shark_secondLlamaModel, _ = shark_compile_through_fx_int(
+                secondLlamaModel,
+                inputs,
+                extended_model_name=extended_model_name,
+                precision=self.precision,
+                f16_input_mask=f16_input_mask,
+                save_dir=tempfile.gettempdir(),
+                debug=False,
+                generate_or_load_vmfb=True,
+                extra_args=[],
+                device=self.device,
+                mlir_dialect="tm_tensor",
+            )
+        else:
+            shark_secondLlamaModel, _ = shark_compile_through_fx(
+                secondLlamaModel,
+                inputs,
+                extended_model_name=extended_model_name,
+                precision=self.precision,
+                f16_input_mask=f16_input_mask,
+                save_dir=tempfile.gettempdir(),
+                debug=False,
+                generate_or_load_vmfb=True,
+                extra_args=[],
+                device=self.device,
+                mlir_dialect="tm_tensor",
+            )
        print(f"Generated {extended_model_name}.vmfb")
        self.second_llama = shark_secondLlamaModel
        return shark_secondLlamaModel
@@ -827,7 +1049,7 @@ def answer(
         )
         i = 0
         timesRan = 0
-        is_fp16 = True
+        is_fp16 = self.precision == "fp16"
         llama_list = []
         isPyTorchVariant = False
         while True:
@@ -982,7 +1204,7 @@ def upload_img(self, image, conv, img_list):
 
         with self.model.maybe_autocast():
             shark_visionModel = self.compile_vision_model()
-            if self.precision == "fp16":
+            if self.precision in ["int4", "int8", "fp16"]:
                 image = image.to(torch.float16)
             image_embeds = shark_visionModel("forward", (image,))
             # image_embeds = shark_visionModel.forward(image)
@@ -995,8 +1217,6 @@ def upload_img(self, image, conv, img_list):
             query_tokens = self.model.query_tokens.expand(
                 image_embeds.shape[0], -1, -1
             ).to(device)
-            # if self.precision == "fp16":
-            #     query_tokens = query_tokens.to(torch.float16)
             shark_QformerBertModel = self.compile_qformer_model()
             query_output = shark_QformerBertModel(
                 "forward",
@@ -1006,7 +1226,6 @@ def upload_img(self, image, conv, img_list):
                     image_atts,
                 ),
             )
-            # query_output = shark_QformerBertModel.forward(query_tokens, image_embeds, image_atts)
             query_output = torch.from_numpy(query_output)
 
             inputs_llama = self.model.llama_proj(query_output)
@@ -1132,13 +1351,16 @@ def get_context_emb(self, conv, img_list, max_allowed_tokens=200):
         )
         sys.exit()
 
+    vision_model_precision = precision
+    if precision in ["int4", "int8"]:
+        vision_model_precision = "fp16"
     vision_model_vmfb_path = (
-        Path("vision_model_fp16_cuda.vmfb")
+        Path(f"vision_model_{vision_model_precision}_{device}.vmfb")
         if args.vision_model_vmfb_path is None
         else Path(args.vision_model_vmfb_path)
     )
     qformer_vmfb_path = (
-        Path("qformer_fp32_cuda.vmfb")
+        Path(f"qformer_fp32_{device}.vmfb")
         if args.qformer_vmfb_path is None
         else Path(args.qformer_vmfb_path)
     )
diff --git a/apps/stable_diffusion/web/ui/minigpt4_ui.py b/apps/stable_diffusion/web/ui/minigpt4_ui.py
index 1a49159cdc..bfe0641301 100644
--- a/apps/stable_diffusion/web/ui/minigpt4_ui.py
+++ b/apps/stable_diffusion/web/ui/minigpt4_ui.py
@@ -6,6 +6,7 @@
     MiniGPT4,
     CONV_VISION,
 )
+from pathlib import Path
 
 
 chat = None
@@ -27,19 +28,23 @@ def gradio_reset(chat_state, img_list):
     )
 
 
-def upload_img(gr_img, text_input, chat_state, device):
+def upload_img(gr_img, text_input, chat_state, device, precision, _compile):
     global chat
     if chat is None:
-        from apps.language_models.src.pipelines.minigpt4_pipeline import (
-            MiniGPT4,
-        )
-
+        vision_model_precision = precision
+        if precision in ["int4", "int8"]:
+            vision_model_precision = "fp16"
+        vision_model_vmfb_path = Path(f"vision_model_{vision_model_precision}_{device}.vmfb")
+        qformer_vmfb_path = Path(f"qformer_fp32_{device}.vmfb")
         chat = MiniGPT4(
             model_name="MiniGPT4",
             hf_model_path=None,
             max_new_tokens=30,
             device=device,
-            precision="fp16",
+            precision=precision,
+            _compile=_compile,
+            vision_model_vmfb_path=vision_model_vmfb_path,
+            qformer_vmfb_path=qformer_vmfb_path,
         )
     if gr_img is None:
         return None, None, gr.update(interactive=True), chat_state, None
@@ -141,10 +146,25 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
                     placeholder="Please upload your image first",
                     interactive=False,
                 )
+                precision = gr.Radio(
+                    label="Precision",
+                    value="int8",
+                    choices=[
+                        "int8",
+                        "fp16",
+                        "fp32",
+                    ],
+                    visible=True,
+                )
+                _compile = gr.Checkbox(
+                    value=False,
+                    label="Compile",
+                    interactive=True,
+                )
 
 
     upload_button.click(
         upload_img,
-        [image, text_input, chat_state, device],
+        [image, text_input, chat_state, device, precision, _compile],
         [image, text_input, upload_button, chat_state, img_list],
     )
diff --git a/shark/shark_compile.py b/shark/shark_compile.py
index 79431155f5..19f970b8d5 100644
--- a/shark/shark_compile.py
+++ b/shark/shark_compile.py
@@ -2,6 +2,37 @@
 import tempfile
 from shark.shark_inference import SharkInference
 from shark.shark_importer import import_with_fx
+import torch
+import torch_mlir
+from torch_mlir.compiler_utils import run_pipeline_with_repro_report
+from typing import List, Tuple
+from io import BytesIO
+from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
+
+
+def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
+    if len(lhs) == 3 and len(rhs) == 2:
+        return [lhs[0], lhs[1], rhs[0]]
+    elif len(lhs) == 2 and len(rhs) == 2:
+        return [lhs[0], rhs[0]]
+    else:
+        raise ValueError("Input shapes not supported.")
+
+
+def brevitas〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int:
+    # output dtype is the dtype of the lhs float input
+    lhs_rank, lhs_dtype = lhs_rank_dtype
+    return lhs_dtype
+
+
+def brevitas〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None:
+    return
+
+
+brevitas_matmul_rhs_group_quant_library = [
+    brevitas〇matmul_rhs_group_quant〡shape,
+    brevitas〇matmul_rhs_group_quant〡dtype,
+    brevitas〇matmul_rhs_group_quant〡has_value_semantics]
 
 
 def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
@@ -39,11 +70,80 @@ def compile_module(
     return shark_module
 
 
+def compile_int_precision(model, inputs, precision, device, generate_vmfb, extended_model_name):
+    weight_bit_width = 4 if precision == "int4" else 8
+    weight_group_size = 128
+    quantize_model(
+        get_model_impl(model),
+        dtype=torch.float32,
+        weight_quant_type="asym",
+        weight_bit_width=weight_bit_width,
+        weight_param_method="stats",
+        weight_scale_precision="float",
+        weight_quant_granularity="per_group",
+        weight_group_size=weight_group_size,
+        quantize_weight_zero_point=False,
+        input_bit_width=None,
+        input_scale_type="float",
+        input_param_method="stats",
+        input_quant_type="asym",
+        input_quant_granularity="per_tensor",
+        quantize_input_zero_point=False,
+        seqlen=2048,
+    )
+    print("Weight quantization applied.")
+    torchscript_module = import_with_fx(
+        model,
+        inputs,
+        precision=precision,
+        mlir_type="torchscript",
+    )
+    mlir_module = torch_mlir.compile(
+        torchscript_module,
+        inputs,
+        output_type="torch",
+        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
+        extra_library=brevitas_matmul_rhs_group_quant_library,
+        use_tracing=False,
+        verbose=False,
+    )
+    print(f"[DEBUG] converting torch to linalg")
+    run_pipeline_with_repro_report(
+        mlir_module,
+        "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
+        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
+    )
+    from contextlib import redirect_stdout
+
+    mlir_file_path = os.path.join(os.getcwd(), f"{extended_model_name}_linalg.mlir")
+    with open(mlir_file_path, 'w') as f:
+        with redirect_stdout(f):
+            print(mlir_module.operation.get_asm())
+    mlir_module = str(mlir_module)
+    mlir_module = mlir_module.encode("UTF-8")
+    mlir_module = BytesIO(mlir_module)
+    bytecode = mlir_module.read()
+    print(f"Elided IR written for {extended_model_name}")
+    return bytecode
+    shark_module = SharkInference(
+        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
+    )
+    extra_args = [
+        "--iree-hal-dump-executable-sources-to=ies",
+        "--iree-vm-target-truncate-unsupported-floats",
+        "--iree-codegen-check-ir-before-llvm-conversion=false",
+        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
+    ]
+    return (
+        compile_module(shark_module, extended_model_name=extended_model_name, generate_vmfb=generate_vmfb, extra_args=extra_args),
+        bytecode
+    )
+
+
 def shark_compile_through_fx(
     model,
     inputs,
     extended_model_name,
-    is_f16=False,
+    precision,
     f16_input_mask=None,
     save_dir=tempfile.gettempdir(),
     debug=False,
@@ -52,6 +152,7 @@ def shark_compile_through_fx(
     device=None,
     mlir_dialect="tm_tensor",
 ):
+    is_f16 = precision == "fp16"
     if generate_or_load_vmfb:
         shark_module = load_vmfb(
             extended_model_name=extended_model_name,
@@ -70,18 +171,27 @@ def shark_compile_through_fx(
     if "cuda" in device:
         shark_args.enable_tf32 = True
 
-    (
-        mlir_module,
-        _,
-    ) = import_with_fx(
-        model=model,
-        inputs=inputs,
-        is_f16=is_f16,
-        f16_input_mask=f16_input_mask,
-        debug=debug,
-        model_name=extended_model_name,
-        save_dir=save_dir,
-    )
+    if precision in ["int4", "int8"]:
+        mlir_module = compile_int_precision(model, inputs, precision, device, generate_or_load_vmfb, extended_model_name)
+        extra_args = [
+            "--iree-hal-dump-executable-sources-to=ies",
+            "--iree-vm-target-truncate-unsupported-floats",
+            "--iree-codegen-check-ir-before-llvm-conversion=false",
"--iree-vm-bytecode-module-output-format=flatbuffer-binary", + ] + else: + ( + mlir_module, + _, + ) = import_with_fx( + model=model, + inputs=inputs, + is_f16=is_f16, + f16_input_mask=f16_input_mask, + debug=debug, + model_name=extended_model_name, + save_dir=save_dir, + ) shark_module = SharkInference( mlir_module,