Skip to content

Commit

Permalink
Update precision check for vicuna (#1610)
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchen62 authored Jun 29, 2023
1 parent 5779e8c commit 534de05
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def __init__(
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if not load_mlir_from_shark_tank and precision in ["int4", "int8"]:
print(
"int4 and int8 are only available from SHARK tank, please set --load_mlir_from_shark_tank, using fp32 now"
)
precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
Expand Down Expand Up @@ -101,7 +106,7 @@ def compile_first_vicuna(self):
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank for fp32/fp16
# download MLIR from shark_tank
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
self.first_vicuna_mlir_path.absolute(),
Expand All @@ -118,7 +123,7 @@ def compile_first_vicuna(self):
)
else:
print(
f"Only fp32 and fp16 mlir added to tank, generating {self.precision} mlir on device."
f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device."
)

if not mlir_generated:
Expand Down Expand Up @@ -243,7 +248,7 @@ def compile_second_vicuna(self):
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank for fp32/fp16
# download MLIR from shark_tank
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
self.second_vicuna_mlir_path.absolute(),
Expand All @@ -260,7 +265,7 @@ def compile_second_vicuna(self):
)
else:
print(
"Only fp32 mlir added to tank, generating mlir on device."
"Only fp32/fp16/int8/int4 mlir added to tank, generating mlir on device."
)

if not mlir_generated:
Expand Down

0 comments on commit 534de05

Please sign in to comment.