Commit 198c42c

Formatting and init files.

monorimet committed Feb 5, 2024
1 parent c4f3526 commit 198c42c
Showing 13 changed files with 118 additions and 55 deletions.
7 changes: 5 additions & 2 deletions apps/shark_studio/api/initializers.py
@@ -7,6 +7,7 @@
 from threading import Thread

 from apps.shark_studio.modules.timer import startup_timer
+
 # from apps.shark_studio.web.utils.tmp_configs import (
 #     config_tmp,
 #     clear_tmp_mlir,
@@ -88,7 +89,9 @@ def dumpstacks():
 def setup_middleware(app):
     from starlette.middleware.gzip import GZipMiddleware

-    app.middleware_stack = None  # reset current middleware to allow modifying user provided list
+    app.middleware_stack = (
+        None  # reset current middleware to allow modifying user provided list
+    )
     app.add_middleware(GZipMiddleware, minimum_size=1000)
     configure_cors_middleware(app)
     app.build_middleware_stack()  # rebuild middleware stack on-the-fly
@@ -104,7 +107,7 @@ def configure_cors_middleware(app):
         "allow_credentials": True,
     }
     if cmd_opts.api_accept_origin:
-        cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(',')
+        cors_options["allow_origins"] = cmd_opts.api_accept_origin.split(",")

     app.add_middleware(CORSMiddleware, **cors_options)
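
For orientation, a minimal sketch of how the two helpers above compose, using only names that appear in this file (the FastAPI app construction itself is illustrative):

    from fastapi import FastAPI

    app = FastAPI()
    setup_middleware(app)
    # Clears app.middleware_stack so middleware can still be added after app
    # creation, registers GZipMiddleware, then configure_cors_middleware()
    # builds allow_origins from the comma-separated cmd_opts.api_accept_origin
    # string, and finally the middleware stack is rebuilt so changes apply.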
22 changes: 16 additions & 6 deletions apps/shark_studio/api/llm.py
@@ -102,11 +102,15 @@ def __init__(
             self.file_spec += "_streaming"
         self.streaming_llm = streaming_llm

-        self.tempfile_name = get_resource_path(os.path.join("..", f"{self.file_spec}.tempfile"))
+        self.tempfile_name = get_resource_path(
+            os.path.join("..", f"{self.file_spec}.tempfile")
+        )
         # TODO: Tag vmfb with target triple of device instead of HAL backend
-        self.vmfb_name = str(get_resource_path(
-            os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
-        ))
+        self.vmfb_name = str(
+            get_resource_path(
+                os.path.join("..", f"{self.file_spec}_{self.backend}.vmfb.tempfile")
+            )
+        )

         self.max_tokens = llm_model_map[model_name]["max_tokens"]
         self.iree_module_dict = None
@@ -253,7 +257,10 @@ def format_out(results):
                 token_len += 1

         history.append(format_out(token))
-        while format_out(token) != llm_model_map["llama2_7b"]["stop_token"] and len(history) < self.max_tokens:
+        while (
+            format_out(token) != llm_model_map["llama2_7b"]["stop_token"]
+            and len(history) < self.max_tokens
+        ):
             dec_time = time.time()
             if self.streaming_llm and self.model["get_seq_step"]() > 600:
                 print("Evicting cache space!")
@@ -378,7 +385,9 @@ def llm_chat_api(InputData: dict):
     # TODO: add role dict for different models
     if is_chat_completion_api:
         # TODO: add funtionality for multiple messages
-        prompt = append_user_prompt(InputData["messages"][0]["role"], InputData["messages"][0]["content"])
+        prompt = append_user_prompt(
+            InputData["messages"][0]["role"], InputData["messages"][0]["content"]
+        )
     else:
         prompt = InputData["prompt"]
     print("prompt = ", prompt)
@@ -413,6 +422,7 @@ def llm_chat_api(InputData: dict):
         "choices": choices,
     }

+
 if __name__ == "__main__":
     lm = LanguageModel(
         "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
1 change: 1 addition & 0 deletions apps/shark_studio/api/utils.py
@@ -290,6 +290,7 @@ def get_devices_by_name(driver_name):
     available_devices.extend(cpu_device)
     return available_devices

+
 # Generate and return a new seed if the provided one is not in the
 # supported range (including -1)
 def sanitize_seed(seed: int | str):
10 changes: 6 additions & 4 deletions apps/shark_studio/modules/embeddings.py
@@ -41,10 +41,12 @@ def processLoRA(model, use_lora, splitting_prefix, lora_strength=0.75):
                 state_dict[f"{stem}up.weight"],
                 state_dict[f"{stem}down.weight"],
                 state_dict.get(f"{stem}mid.weight", None),
-                state_dict[f"{weight_key}.alpha"]
-                / state_dict[f"{stem}up.weight"].shape[1]
-                if f"{weight_key}.alpha" in state_dict
-                else 1.0,
+                (
+                    state_dict[f"{weight_key}.alpha"]
+                    / state_dict[f"{stem}up.weight"].shape[1]
+                    if f"{weight_key}.alpha" in state_dict
+                    else 1.0
+                ),
             )

         # Directly update weight in model
6 changes: 2 additions & 4 deletions apps/shark_studio/modules/img_processing.py
@@ -174,9 +174,7 @@ def process_sd_init_image(self, sd_init_image, resample_type):
     if isinstance(sd_init_image, str):
         if os.path.isfile(sd_init_image):
             sd_init_image = Image.open(sd_init_image, mode="r").convert("RGB")
-            image, is_img2img = self.process_sd_init_image(
-                sd_init_image, resample_type
-            )
+            image, is_img2img = self.process_sd_init_image(sd_init_image, resample_type)
         else:
             image = None
             is_img2img = False
@@ -201,4 +199,4 @@ def process_sd_init_image(self, sd_init_image, resample_type):
     image_arr = 2 * (image_arr - 0.5)
     is_img2img = True
     image = image_arr
-    return image, is_img2img
\ No newline at end of file
+    return image, is_img2img
48 changes: 24 additions & 24 deletions apps/shark_studio/modules/schedulers.py
@@ -50,30 +50,30 @@ def get_schedulers(model_id):
     schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
         model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
     )
-    schedulers[
-        "DPMSolverMultistepKarras"
-    ] = DPMSolverMultistepScheduler.from_pretrained(
-        model_id,
-        subfolder="scheduler",
-        use_karras_sigmas=True,
-    )
-    schedulers[
-        "DPMSolverMultistepKarras++"
-    ] = DPMSolverMultistepScheduler.from_pretrained(
-        model_id,
-        subfolder="scheduler",
-        algorithm_type="dpmsolver++",
-        use_karras_sigmas=True,
-    )
+    schedulers["DPMSolverMultistepKarras"] = (
+        DPMSolverMultistepScheduler.from_pretrained(
+            model_id,
+            subfolder="scheduler",
+            use_karras_sigmas=True,
+        )
+    )
+    schedulers["DPMSolverMultistepKarras++"] = (
+        DPMSolverMultistepScheduler.from_pretrained(
+            model_id,
+            subfolder="scheduler",
+            algorithm_type="dpmsolver++",
+            use_karras_sigmas=True,
+        )
+    )
     schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
         model_id,
         subfolder="scheduler",
     )
-    schedulers[
-        "EulerAncestralDiscrete"
-    ] = EulerAncestralDiscreteScheduler.from_pretrained(
-        model_id,
-        subfolder="scheduler",
-    )
+    schedulers["EulerAncestralDiscrete"] = (
+        EulerAncestralDiscreteScheduler.from_pretrained(
+            model_id,
+            subfolder="scheduler",
+        )
+    )
     schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
         model_id,
@@ -83,11 +83,11 @@ def get_schedulers(model_id):
         model_id,
         subfolder="scheduler",
     )
-    schedulers[
-        "KDPM2AncestralDiscrete"
-    ] = KDPM2AncestralDiscreteScheduler.from_pretrained(
-        model_id,
-        subfolder="scheduler",
-    )
+    schedulers["KDPM2AncestralDiscrete"] = (
+        KDPM2AncestralDiscreteScheduler.from_pretrained(
+            model_id,
+            subfolder="scheduler",
+        )
+    )
     schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
         model_id,
1 change: 1 addition & 0 deletions apps/shark_studio/tests/api_test.py
@@ -30,6 +30,7 @@
 # for i in shark_sd_fn_dict_input(sd_kwargs):
 #     print(i)

+
 class LLMAPITest(unittest.TestCase):
     def test01_LLMSmall(self):
         lm = LanguageModel(
41 changes: 41 additions & 0 deletions apps/shark_studio/tests/export_unet.py
@@ -0,0 +1,41 @@
+import torch
+from diffusers import (
+    UNet2DConditionModel,
+)
+from torch.fx.experimental.proxy_tensor import make_fx
+
+
+class UnetModel(torch.nn.Module):
+    def __init__(self, hf_model_name):
+        super().__init__()
+        self.unet = UNet2DConditionModel.from_pretrained(
+            hf_model_name,
+            subfolder="unet",
+        )
+
+    def forward(self, sample, timestep, encoder_hidden_states, guidance_scale):
+        samples = torch.cat([sample] * 2)
+        unet_out = self.unet.forward(
+            samples, timestep, encoder_hidden_states, return_dict=False
+        )[0]
+        noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
+        noise_pred = noise_pred_uncond + guidance_scale * (
+            noise_pred_text - noise_pred_uncond
+        )
+        return noise_pred
+
+
+if __name__ == "__main__":
+    hf_model_name = "CompVis/stable-diffusion-v1-4"
+    unet = UnetModel(hf_model_name)
+    inputs = (torch.randn(1, 4, 64, 64), 1, torch.randn(2, 77, 768), 7.5)
+
+    fx_g = make_fx(
+        unet,
+        decomposition_table={},
+        tracing_mode="symbolic",
+        _allow_non_fake_inputs=True,
+        _allow_fake_constant=False,
+    )(*inputs)
+
+    print(fx_g)
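
Since make_fx should return a torch.fx.GraphModule, the traced UNet can be run and inspected directly. A minimal sketch under that assumption, reusing the inputs tuple above:

    out = fx_g(*inputs)          # executes the traced graph; CFG-combined prediction
    print(out.shape)             # expected: torch.Size([1, 4, 64, 64])
    fx_g.graph.print_tabular()   # per-node listing of the traced ops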
11 changes: 6 additions & 5 deletions apps/shark_studio/tests/rest_api_test.py
@@ -19,13 +19,14 @@ def llm_chat_test(verbose=False):

     data = {
         "model": "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
-        "messages": [{
-            "role": "",
-            "content": prompt,
-        }],
+        "messages": [
+            {
+                "role": "",
+                "content": prompt,
+            }
+        ],
         "device": "vulkan://0",
         "max_tokens": 4096,
-
     }

     res = requests.post(url=url, json=data, headers=headers, timeout=1000)
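
A minimal sketch of reading the reply, assuming the endpoint returns the OpenAI-style body assembled in llm_chat_api above (this diff does not show the exact shape of each "choices" entry):

    res.raise_for_status()   # surface non-2xx responses loudly
    body = res.json()
    print(body["choices"])   # the generated completion(s)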
9 changes: 5 additions & 4 deletions apps/shark_studio/web/api/compat.py
@@ -19,7 +19,8 @@
 from fastapi.encoders import jsonable_encoder

 from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
-#from sdapi_v1 import shark_sd_api
+
+# from sdapi_v1 import shark_sd_api
 from apps.shark_studio.api.llm import llm_chat_api

@@ -155,7 +156,7 @@ def handle_exception(request: Request, e: Exception):
         )
     else:
         print(message)
-        raise(e)
+        raise (e)
     return JSONResponse(
         status_code=vars(e).get("status_code", 500),
         content=jsonable_encoder(err),
@@ -183,8 +184,8 @@ def __init__(self, app: FastAPI, queue_lock: Lock):
         self.app = app
         self.queue_lock = queue_lock
         api_middleware(self.app)
-        #self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
-        #self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["POST"])
+        # self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["POST"])
         # self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["POST"])
         # self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse)
         # self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse)
1 change: 1 addition & 0 deletions apps/shark_studio/web/index.py
@@ -21,6 +21,7 @@

 def create_api(app):
     from apps.shark_studio.web.api.compat import ApiCompat, FIFOLock
+
     queue_lock = FIFOLock()
     api = ApiCompat(app, queue_lock)
     return api
16 changes: 10 additions & 6 deletions apps/shark_studio/web/ui/sd.py
@@ -275,9 +275,11 @@ def import_original(original_img, width, height):
                 label=f"Custom VAE Models",
                 info=sd_vae_info,
                 elem_id="custom_model",
-                value=os.path.basename(cmd_opts.custom_vae)
-                if cmd_opts.custom_vae
-                else "None",
+                value=(
+                    os.path.basename(cmd_opts.custom_vae)
+                    if cmd_opts.custom_vae
+                    else "None"
+                ),
                 choices=["None"] + get_checkpoints("vae"),
                 allow_custom_value=True,
                 scale=1,
@@ -624,9 +626,11 @@ def import_original(original_img, width, height):
             load_sd_config = gr.FileExplorer(
                 label="Load Config",
                 file_count="single",
-                root=cmd_opts.configs_path
-                if cmd_opts.configs_path
-                else get_configs_path(),
+                root=(
+                    cmd_opts.configs_path
+                    if cmd_opts.configs_path
+                    else get_configs_path()
+                ),
                 height=75,
             )
             load_sd_config.change(
Empty file.
