diff --git a/apps/language_models/src/pipelines/vicuna_pipeline.py b/apps/language_models/src/pipelines/vicuna_pipeline.py index 1467a0e1de..c859d7e243 100644 --- a/apps/language_models/src/pipelines/vicuna_pipeline.py +++ b/apps/language_models/src/pipelines/vicuna_pipeline.py @@ -38,6 +38,11 @@ def __init__( super().__init__(model_name, hf_model_path, max_num_tokens) self.max_sequence_length = 256 self.device = device + if not load_mlir_from_shark_tank and precision in ["int4", "int8"]: + print( + "int4 and int8 are only available from SHARK tank, please set --load_mlir_from_shark_tank, using fp32 now" + ) + precision = "fp32" self.precision = precision self.first_vicuna_vmfb_path = first_vicuna_vmfb_path self.second_vicuna_vmfb_path = second_vicuna_vmfb_path @@ -101,7 +106,7 @@ def compile_first_vicuna(self): mlir_generated = False if self.load_mlir_from_shark_tank: if self.precision in ["fp32", "fp16", "int8", "int4"]: - # download MLIR from shark_tank for fp32/fp16 + # download MLIR from shark_tank download_public_file( f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}", self.first_vicuna_mlir_path.absolute(), @@ -118,7 +123,7 @@ def compile_first_vicuna(self): ) else: print( - f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device." + f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device." ) if not mlir_generated: @@ -243,7 +248,7 @@ def compile_second_vicuna(self): mlir_generated = False if self.load_mlir_from_shark_tank: if self.precision in ["fp32", "fp16", "int8", "int4"]: - # download MLIR from shark_tank for fp32/fp16 + # download MLIR from shark_tank download_public_file( f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}", self.second_vicuna_mlir_path.absolute(), @@ -260,7 +265,7 @@ def compile_second_vicuna(self): ) else: print( - "Only fp32 mlir added to tank, generating mlir on device." + "Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device." ) if not mlir_generated: