Skip to content

Commit

Permalink
int4/int8 vicuna download support (#1609)
Browse files Browse the repository at this point in the history
* set task_topology_max_group to cpu_count

by default. Can be overridden with a flag of the same name

* add download for int4/int8 mlir
  • Loading branch information
dan-garvey committed Jun 29, 2023
1 parent d496053 commit 5779e8c
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions apps/language_models/src/pipelines/vicuna_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,6 @@ def __init__(
super().__init__(model_name, hf_model_path, max_num_tokens)
self.max_sequence_length = 256
self.device = device
if precision in ["int4", "int8"]:
print("int4 and int8 are not supported yet, using fp32")
precision = "fp32"
self.precision = precision
self.first_vicuna_vmfb_path = first_vicuna_vmfb_path
self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
Expand Down Expand Up @@ -103,7 +100,7 @@ def compile_first_vicuna(self):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16"]:
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank for the supported precisions (fp32/fp16/int8/int4)
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.first_vicuna_mlir_path.name}",
Expand Down Expand Up @@ -245,7 +242,7 @@ def compile_second_vicuna(self):
else:
mlir_generated = False
if self.load_mlir_from_shark_tank:
if self.precision in ["fp32", "fp16"]:
if self.precision in ["fp32", "fp16", "int8", "int4"]:
# download MLIR from shark_tank for the supported precisions (fp32/fp16/int8/int4)
download_public_file(
f"gs://shark_tank/vicuna/unsharded/mlir/{self.second_vicuna_mlir_path.name}",
Expand Down

0 comments on commit 5779e8c

Please sign in to comment.