Add int4/int8 vicuna #1598
Conversation
Force-pushed from 3fadf17 to 01195a8
Force-pushed from 7fc74a3 to 66bbf06
Force-pushed from 439d38e to 9e6cbeb
Please update your patch as per the comment - precision in StableDiffusion's API won't be needed because it's already there.
Can you please resolve the failures we see in CI?
We can add Brevitas to the requirements.txt with a git install.
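A minimal sketch of such an entry, assuming the upstream Xilinx/brevitas repository and no particular commit pin (both assumptions, not taken from this PR):

# hypothetical requirements.txt entry installing Brevitas from git
brevitas @ git+https://github.com/Xilinx/brevitas.git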
Please address one comment - rest looks good.
# brevitas custom op lib
apps/language_models/scripts/vicuna.py
This doesn't seem correct. Can you confirm?
It does ignore this file from the format check. How does it look incorrect to you?
If you add it here it won't be detected by git. You want to add it to the black format command line with a --exclude or something.
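For illustration, that could look something like the command below (black's --exclude flag takes a regex matched against file paths; the exact invocation used in this repo's CI is an assumption here):

# hypothetical format-check command excluding the file from black
black . --exclude "apps/language_models/scripts/vicuna.py"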
Please address the following sub-comments and also make changes to the following file in order to integrate with the WebUI as well:
Replace the set of lines at: https://github.com/nod-ai/SHARK/blob/91ab594744ffbe982ababe340fcd208923ecae48/apps/stable_diffusion/web/ui/stablelm_ui.py#L45-L47
with:
from apps.language_models.scripts.vicuna import (
    UnshardedVicuna,
)
Replace the following: https://github.com/nod-ai/SHARK/blob/91ab594744ffbe982ababe340fcd208923ecae48/apps/stable_diffusion/web/ui/stablelm_ui.py#L61
with:
vicuna_model = UnshardedVicuna(
self.shark_model = self.compile()

def get_model_path(self, model_number="first", suffix="mlir"):
    safe_device = "_".join(self.device.split("-"))
Replace this with safe_device = self.device.split("-")[0]
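For illustration, on an example device string such as "cpu-task" (a value assumed here, not taken from the PR), the two expressions differ as follows:

device = "cpu-task"  # example device string
"_".join(device.split("-"))  # -> "cpu_task": keeps the variant, with "-" mapped to "_"
device.split("-")[0]         # -> "cpu": keeps only the base device name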
def compile(self):
    # Cannot load both the models in the memory at once
    # due to memory constraints, hence on demand compilation
    # is being used until the space is enough for both models

    # Testing : DO NOT Download Vmfbs if not found. Modify later
    # download vmfbs for A100
    if (
        not self.first_vicuna_vmfb_path.exists()
        and self.device in ["cuda", "cpu"]
        and self.precision in ["fp32", "fp16"]
    ):
        # combinations that are still in the works
        if not (self.device == "cuda" and self.precision == "fp16"):
            # Will generate vmfb on device
            pass
        else:
            download_public_file(
                f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
                self.first_vicuna_vmfb_path.absolute(),
                single_file=True,
            )
    else:
        # get first vic
        # TODO: Remove after testing to avoid memory overload
        # fvic_shark_model = self.compile_first_vicuna()
        pass
    if (
        not self.second_vicuna_vmfb_path.exists()
        and self.device in ["cuda", "cpu"]
        and self.precision in ["fp32", "fp16"]
    ):
        # combinations that are still in the works
        if not (self.device == "cuda" and self.precision == "fp16"):
            # Will generate vmfb on device
            pass
        else:
            download_public_file(
                f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
                self.second_vicuna_vmfb_path.absolute(),
                single_file=True,
            )
    else:
        # get second vic
        # TODO: Remove after testing to avoid memory overload
        # svic_shark_model = self.compile_second_vicuna()
        pass

    return None
    # return tuple of shark_modules once mem is supported
    # return fvic_shark_model, svic_shark_model
Replace this with:
def compile(self):
    # Cannot load both the models in the memory at once
    # due to memory constraints, hence on demand compilation
    # is being used until the space is enough for both models
    # Testing : DO NOT Download Vmfbs if not found. Modify later
    # download vmfbs for A100
    supported_devices = ["cuda", "cpu-sync", "cpu-task", "cpu"]
    if (
        not self.first_vicuna_vmfb_path.exists()
        and self.device in supported_devices
        and self.precision in ["fp32", "fp16", "int8"]
    ):
        if (self.device == "cuda" and self.precision == "fp16") or (
            self.device in ["cpu-sync", "cpu-task"]
            and self.precision == "int8"
        ):
            download_public_file(
                f"gs://shark_tank/vicuna/unsharded/vmfb/{self.first_vicuna_vmfb_path.name}",
                self.first_vicuna_vmfb_path.absolute(),
                single_file=True,
            )
        else:
            pass
    else:
        # get first vic
        # TODO: Remove after testing to avoid memory overload
        # fvic_shark_model = self.compile_first_vicuna()
        pass
    if (
        not self.second_vicuna_vmfb_path.exists()
        and self.device in supported_devices
        and self.precision in ["fp32", "fp16", "int8"]
    ):
        if (self.device == "cuda" and self.precision == "fp16") or (
            self.device in ["cpu-sync", "cpu-task"]
            and self.precision == "int8"
        ):
            download_public_file(
                f"gs://shark_tank/vicuna/unsharded/vmfb/{self.second_vicuna_vmfb_path.name}",
                self.second_vicuna_vmfb_path.absolute(),
                single_file=True,
            )
        else:
            pass
    else:
        # get second vic
        # TODO: Remove after testing to avoid memory overload
        # svic_shark_model = self.compile_second_vicuna()
        pass
    return None
    # return tuple of shark_modules once mem is supported
    # return fvic_shark_model, svic_shark_model
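The replacement downloads prebuilt vmfbs only for the device/precision combinations published in the shark_tank bucket and falls back to on-device compilation otherwise. A minimal sketch of that gating condition as a standalone predicate (the helper name is hypothetical, not part of the PR):

def should_download_vmfb(device, precision):
    # Prebuilt vmfbs exist only for cuda/fp16 and cpu-sync|cpu-task/int8;
    # all other supported combinations compile the vmfb on device instead.
    return (device == "cuda" and precision == "fp16") or (
        device in ["cpu-sync", "cpu-task"] and precision == "int8"
    )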
- Add brevitas_matmul_rhs_group_quant_library and the vicuna pipeline to the main file to avoid a decompose error
- Add brevitas〇matmul_rhs_group_quant to avoid a signature mismatch error
- TODO: