fp16 fixes for webui
dan-garvey committed Jun 22, 2023
1 parent 88cc242 commit 0f2f671
Showing 4 changed files with 33 additions and 13 deletions.
.gitignore (2 additions, 0 deletions)
@@ -2,6 +2,8 @@
 __pycache__/
 *.py[cod]
 *$py.class
+*.mlir
+*.vmfb
 
 # C extensions
 *.so
apps/language_models/src/pipelines/vicuna_pipeline.py (18 additions, 5 deletions)
@@ -28,10 +28,10 @@ def __init__(
         max_num_tokens=512,
         device="cuda",
         precision="fp32",
-        first_vicuna_mlir_path=Path("first_vicuna.mlir"),
-        second_vicuna_mlir_path=Path("second_vicuna.mlir"),
-        first_vicuna_vmfb_path=Path("first_vicuna.vmfb"),
-        second_vicuna_vmfb_path=Path("second_vicuna.vmfb"),
+        first_vicuna_mlir_path=None,
+        second_vicuna_mlir_path=None,
+        first_vicuna_vmfb_path=None,
+        second_vicuna_vmfb_path=None,
         load_mlir_from_shark_tank=True,
     ) -> None:
         super().__init__(model_name, hf_model_path, max_num_tokens)
@@ -42,10 +42,23 @@ def __init__(
         self.second_vicuna_vmfb_path = second_vicuna_vmfb_path
         self.first_vicuna_mlir_path = first_vicuna_mlir_path
         self.second_vicuna_mlir_path = second_vicuna_mlir_path
+        self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
+        if self.first_vicuna_mlir_path==None:
+            self.first_vicuna_mlir_path = self.get_model_path()
+        if self.second_vicuna_mlir_path==None:
+            self.second_vicuna_mlir_path = self.get_model_path("second")
+        if self.first_vicuna_vmfb_path==None:
+            self.first_vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
+        if self.second_vicuna_vmfb_path==None:
+            self.second_vicuna_vmfb_path = self.get_model_path("second", "vmfb")
         self.tokenizer = self.get_tokenizer()
         self.shark_model = self.compile()
-        self.load_mlir_from_shark_tank = load_mlir_from_shark_tank
 
+    def get_model_path(self, model_number="first", suffix="mlir"):
+        safe_device = "_".join(self.device.split("-"))
+        if suffix=="mlir":
+            return Path(f'{model_number}_vicuna_{self.precision}.{suffix}')
+        return Path(f'{model_number}_vicuna_{safe_device}_{self.precision}.{suffix}')
     def get_tokenizer(self):
         tokenizer = AutoTokenizer.from_pretrained(
             self.hf_model_path, use_fast=False
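With these defaults, callers no longer need to pass artifact paths: when they are left as None, get_model_path derives them from the model half, precision, and device. Below is a minimal standalone sketch of that naming scheme; the device and precision values are illustrative examples, not taken from the commit.

```python
# Standalone sketch of the default artifact naming introduced above.
# device="cpu-task" and precision="fp16" are assumed example values.
from pathlib import Path


def default_model_path(model_number="first", suffix="mlir",
                       device="cpu-task", precision="fp16"):
    # Driver names such as "cpu-task" have their hyphens replaced with
    # underscores before being embedded in the file name.
    safe_device = "_".join(device.split("-"))
    if suffix == "mlir":
        # The MLIR artifact is keyed only by precision.
        return Path(f"{model_number}_vicuna_{precision}.{suffix}")
    # The compiled vmfb is device-specific, so the device is encoded as well.
    return Path(f"{model_number}_vicuna_{safe_device}_{precision}.{suffix}")


print(default_model_path())                  # first_vicuna_fp16.mlir
print(default_model_path("second", "vmfb"))  # second_vicuna_cpu_task_fp16.vmfb
```

Both generated names match the new *.mlir and *.vmfb entries added to .gitignore above.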
apps/stable_diffusion/src/utils/utils.py (4 additions, 1 deletion)
@@ -428,7 +428,10 @@ def get_devices_by_name(driver_name):
     available_devices.extend(vulkan_devices)
     cuda_devices = get_devices_by_name("cuda")
     available_devices.extend(cuda_devices)
-    available_devices.append("device => cpu")
+    cpu_device = get_devices_by_name("cpu-sync")
+    available_devices.extend(cpu_device)
+    cpu_device = get_devices_by_name("cpu-task")
+    available_devices.extend(cpu_device)
     return available_devices


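This change replaces the hard-coded "device => cpu" placeholder with real enumeration of the two CPU drivers. A rough sketch of the resulting pattern, with get_devices_by_name stubbed out since its body is not shown in this diff and the device-string format assumed for illustration:

```python
# Sketch only: get_devices_by_name is stubbed here; the real helper in
# utils.py queries the runtime for the devices a given driver exposes.
def get_devices_by_name(driver_name):
    return [f"{driver_name}://0"]  # assumed output format, for illustration


def list_available_devices():
    available_devices = []
    # After this change both CPU drivers are probed like any other backend,
    # instead of appending a hard-coded "device => cpu" placeholder.
    for driver in ["vulkan", "cuda", "cpu-sync", "cpu-task"]:
        available_devices.extend(get_devices_by_name(driver))
    return available_devices


print(list_available_devices())
# ['vulkan://0', 'cuda://0', 'cpu-sync://0', 'cpu-task://0']  (illustrative)
```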
apps/stable_diffusion/web/ui/stablelm_ui.py (9 additions, 7 deletions)
@@ -41,17 +41,21 @@ def chat(curr_system_message, history, model, device, precision):

     curr_system_message = start_message_vicuna
     if vicuna_model == 0:
-        first_vic_vmfb_path = Path("first_vicuna.vmfb")
-        second_vic_vmfb_path = Path("second_vicuna.vmfb")
         if "cuda" in device:
             device = "cuda"
+        elif "sync" in device:
+            device = "cpu-sync"
+        elif "task" in device:
+            device = "cpu-task"
+        elif "vulkan" in device:
+            device = "vulkan"
+        else:
+            print("unrecognized device")
         vicuna_model = Vicuna(
             "vicuna",
             hf_model_path=model,
             device=device,
             precision=precision,
-            first_vicuna_vmfb_path=first_vic_vmfb_path,
-            second_vicuna_vmfb_path=second_vic_vmfb_path,
         )
     messages = curr_system_message + "".join(
         [
@@ -120,9 +124,7 @@ def chat(curr_system_message, history, model, device, precision):
             "TheBloke/vicuna-7B-1.1-HF",
         ],
     )
-    supported_devices = [
-        device for device in available_devices if "cuda" in device
-    ]
+    supported_devices = available_devices
     enabled = len(supported_devices) > 0
     device = gr.Dropdown(
         label="Device",
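The UI previously filtered the device dropdown to CUDA devices and passed fixed vmfb paths; it now accepts any enumerated device and normalizes the dropdown string to the short driver name the Vicuna pipeline expects. A small sketch of that normalization; the sample dropdown strings are assumptions, not values from the diff:

```python
# Sketch of the normalization added in chat(); the sample dropdown strings
# below are made up for illustration.
def normalize_device(device: str) -> str:
    if "cuda" in device:
        return "cuda"
    elif "sync" in device:
        return "cpu-sync"
    elif "task" in device:
        return "cpu-task"
    elif "vulkan" in device:
        return "vulkan"
    print("unrecognized device")
    return device


print(normalize_device("AMD Radeon RX 6800 => vulkan://0"))  # vulkan
print(normalize_device("cpu-task://local"))                  # cpu-task
```

In the diff itself the normalized value is assigned back to device before Vicuna is constructed, and the explicit vmfb path arguments are dropped, so the pipeline falls back to the get_model_path defaults shown earlier.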
