From 5779e8c039b0a980ddbe2a835e64ba5eeb3ed285 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Thu, 29 Jun 2023 15:35:51 -0500 Subject: [PATCH] int4/int8 vicuna download support (#1609) * set task_topology_max_group to cpu_count by default. Can be overriden with a flag of the same str * add download for int4/int8 mlir --- apps/language_models/src/pipelines/vicuna_pipeline.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py index 70d9cced95..1467a0e1de 100644 --- a/apps/language_models/src/pipelines/vicuna_pipeline.py +++ b/apps/language_models/src/pipelines/vicuna_pipeline.py @@ -38,9 +38,6 @@ def __init__( super().__init__(model_name, hf_model_path, max_num_tokens) self.max_sequence_length = 256 self.device = device - if precision in ["int4", "int8"]: - print("int4 and int8 are not supported yet, using fp32") - precision = "fp32" self.precision = precision self.first_vicuna_vmfb_path = first_vicuna_vmfb_path self.second_vicuna_vmfb_path = second_vicuna_vmfb_path @@ -103,7 +100,7 @@ def compile_first_vicuna(self): else: mlir_generated = False if self.load_mlir_from_shark_tank: - if self.precision in ["fp32", "fp16"]: + if self.precision in ["fp32", "fp16", "int8", "int4"]: # download MLIR from shark_tank for fp32/fp16 download_public_file( f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}", @@ -245,7 +242,7 @@ def compile_second_vicuna(self): else: mlir_generated = False if self.load_mlir_from_shark_tank: - if self.precision in ["fp32", "fp16"]: + if self.precision in ["fp32", "fp16", "int8", "int4"]: # download MLIR from shark_tank for fp32/fp16 download_public_file( f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",