diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py
index 70d9cced95..1467a0e1de 100644
--- a/apps/language_models/src/pipelines/vicuna_pipeline.py
+++ b/apps/language_models/src/pipelines/vicuna_pipeline.py
@@ -38,9 +38,6 @@ def __init__(
         super().__init__(model_name, hf_model_path, max_num_tokens)
         self.max_sequence_length = 256
         self.device = device
-        if precision in ["int4", "int8"]:
-            print("int4 and int8 are not supported yet, using fp32")
-            precision = "fp32"
         self.precision = precision
         self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
         self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
@@ -103,7 +100,7 @@ def compile_first_vicuna(self):
         else:
             mlir_generated = False
             if self.load_mlir_from_shark_tank:
-                if self.precision in ["fp32", "fp16"]:
-                    # download MLIR from shark_tank for fp32/fp16
+                if self.precision in ["fp32", "fp16", "int8", "int4"]:
+                    # download MLIR from shark_tank for fp32/fp16/int8/int4
                     download_public_file(
                         f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
@@ -245,7 +242,7 @@ def compile_second_vicuna(self):
         else:
             mlir_generated = False
             if self.load_mlir_from_shark_tank:
-                if self.precision in ["fp32", "fp16"]:
-                    # download MLIR from shark_tank for fp32/fp16
+                if self.precision in ["fp32", "fp16", "int8", "int4"]:
+                    # download MLIR from shark_tank for fp32/fp16/int8/int4
                     download_public_file(
                         f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
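
Reviewer note: below is a minimal, illustrative sketch (not code from this patch) of the control flow the three hunks converge on. With the fp32 fallback removed from __init__, int8/int4 now reach the shark_tank download path. The helper name fetch_or_generate_mlir and the stubbed download_public_file body are assumptions for illustration; only the precision list and the gs://shark_tank/... URL pattern come from the diff, and the real download_public_file (with its exact signature) lives elsewhere in the repository.

# Illustrative sketch only -- mirrors the patched condition in isolation.
from pathlib import Path


def download_public_file(gs_url: str, destination: Path) -> None:
    # Stub standing in for the SHARK utility called in the diff; the real
    # helper downloads a public GCS object and may take extra arguments.
    raise NotImplementedError(f"would fetch {gs_url} -> {destination}")


def fetch_or_generate_mlir(
    precision: str, mlir_path: Path, load_mlir_from_shark_tank: bool
) -> bool:
    # Hypothetical helper: returns True once the MLIR file exists locally,
    # False if the caller must generate it. After this patch, int8/int4
    # take the download path instead of being coerced to fp32 in __init__.
    if mlir_path.exists():
        return True
    if load_mlir_from_shark_tank and precision in ["fp32", "fp16", "int8", "int4"]:
        download_public_file(
            f"gs://shark_tank/vicuna/unsharded/mlir/{mlir_path.name}",
            mlir_path.absolute(),
        )
        return mlir_path.exists()
    return False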