-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
760 additions
and
329 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from turbine_models.custom_models.sd_inference import clip, unet, vae | ||
from shark.iree_utils.compile_utils import get_iree_compiled_module | ||
from apps.shark_studio.api.utils import get_resource_path | ||
import iree.runtime as ireert | ||
import gc | ||
import torch | ||
|
||
sd_model_map = { | ||
"sd15": { | ||
"base_model_id": "runwayml/stable-diffusion-v1-5" | ||
"clip": { | ||
"initializer": clip.export_clip_model, | ||
"max_tokens": 77, | ||
} | ||
"unet": { | ||
"initializer": unet.export_unet_model, | ||
"max_tokens": 512, | ||
} | ||
"vae_decode": { | ||
"initializer": vae.export_vae_model,, | ||
} | ||
} | ||
} | ||
|
||
|
||
class SharkStableDiffusionPipeline: | ||
def __init__( | ||
self, model_name, , device=None, precision="fp32" | ||
): | ||
print(sd_model_map[model_name]) | ||
self.hf_model_name = llm_model_map[model_name]["hf_model_name"] | ||
self.torch_ir, self.tokenizer = llm_model_map[model_name][ | ||
"initializer" | ||
](self.hf_model_name, hf_auth_token, compile_to="torch") | ||
self.tempfile_name = get_resource_path("llm.torch.tempfile") | ||
with open(self.tempfile_name, "w+") as f: | ||
f.write(self.torch_ir) | ||
del self.torch_ir | ||
gc.collect() | ||
|
||
self.device = device | ||
self.precision = precision | ||
self.max_tokens = llm_model_map[model_name]["max_tokens"] | ||
self.iree_module_dict = None | ||
self.compile() | ||
|
||
def compile(self) -> None: | ||
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink". | ||
self.iree_module_dict = get_iree_compiled_module( | ||
self.tempfile_name, device=self.device, frontend="torch" | ||
) | ||
# TODO: delete the temp file | ||
|
||
def generate_images( | ||
self, | ||
prompt, | ||
): | ||
history = [] | ||
for iter in range(self.max_tokens): | ||
input_tensor = self.tokenizer( | ||
prompt, return_tensors="pt" | ||
).input_ids | ||
device_inputs = [ | ||
ireert.asdevicearray( | ||
self.iree_module_dict["config"], input_tensor | ||
) | ||
] | ||
if iter == 0: | ||
token = torch.tensor( | ||
self.iree_module_dict["vmfb"]["run_initialize"]( | ||
*device_inputs | ||
).to_host()[0][0] | ||
) | ||
else: | ||
token = torch.tensor( | ||
self.iree_module_dict["vmfb"]["run_forward"]( | ||
*device_inputs | ||
).to_host()[0][0] | ||
) | ||
|
||
history.append(token) | ||
yield self.tokenizer.decode(history) | ||
|
||
if token == llm_model_map["llama2_7b"]["stop_token"]: | ||
break | ||
|
||
for i in range(len(history)): | ||
if type(history[i]) != int: | ||
history[i] = int(history[i]) | ||
result_output = self.tokenizer.decode(history) | ||
yield result_output | ||
|
||
|
||
if __name__ == "__main__": | ||
lm = LanguageModel( | ||
"llama2_7b", | ||
hf_auth_token="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", | ||
device="cpu-task", | ||
) | ||
print("model loaded") | ||
for i in lm.chat("Hello, I am a robot."): | ||
print(i) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
import base64 | ||
import io | ||
import os | ||
import time | ||
import datetime | ||
import uvicorn | ||
import ipaddress | ||
import requests | ||
import gradio as gr | ||
from threading import Lock | ||
from io import BytesIO | ||
from fastapi import APIRouter, Depends, FastAPI, Request, Response | ||
from fastapi.security import HTTPBasic, HTTPBasicCredentials | ||
from fastapi.exceptions import HTTPException | ||
from fastapi.responses import JSONResponse | ||
from fastapi.encoders import jsonable_encoder | ||
|
||
from apps.shark_studio. import sd_samplers, postprocessing, errors, restart | ||
from sdapi_v1 import shark_sd_api | ||
from api.llm import chat_api | ||
|
||
|
||
def decode_base64_to_image(encoding): | ||
if encoding.startswith("http://") or encoding.startswith("https://"): | ||
if not opts.api_enable_requests: | ||
raise HTTPException(status_code=500, detail="Requests not allowed") | ||
|
||
if opts.api_forbid_local_requests and not verify_url(encoding): | ||
raise HTTPException(status_code=500, detail="Request to local resource not allowed") | ||
|
||
headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} | ||
response = requests.get(encoding, timeout=30, headers=headers) | ||
try: | ||
image = Image.open(BytesIO(response.content)) | ||
return image | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail="Invalid image url") from e | ||
|
||
if encoding.startswith("data:image/"): | ||
encoding = encoding.split(";")[1].split(",")[1] | ||
try: | ||
image = Image.open(BytesIO(base64.b64decode(encoding))) | ||
return image | ||
except Exception as e: | ||
raise HTTPException(status_code=500, detail="Invalid encoded image") from e | ||
|
||
|
||
def encode_pil_to_base64(image): | ||
with io.BytesIO() as output_bytes: | ||
|
||
if opts.samples_format.lower() == 'png': | ||
use_metadata = False | ||
metadata = PngImagePlugin.PngInfo() | ||
for key, value in image.info.items(): | ||
if isinstance(key, str) and isinstance(value, str): | ||
metadata.add_text(key, value) | ||
use_metadata = True | ||
image.save(output_bytes, format="PNG", pnginfo=(metadata if use_metadata else None), quality=opts.jpeg_quality) | ||
|
||
elif opts.samples_format.lower() in ("jpg", "jpeg", "webp"): | ||
if image.mode == "RGBA": | ||
image = image.convert("RGB") | ||
parameters = image.info.get('parameters', None) | ||
exif_bytes = piexif.dump({ | ||
"Exif": { piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(parameters or "", encoding="unicode") } | ||
}) | ||
if opts.samples_format.lower() in ("jpg", "jpeg"): | ||
image.save(output_bytes, format="JPEG", exif = exif_bytes, quality=opts.jpeg_quality) | ||
else: | ||
image.save(output_bytes, format="WEBP", exif = exif_bytes, quality=opts.jpeg_quality) | ||
|
||
else: | ||
raise HTTPException(status_code=500, detail="Invalid image format") | ||
|
||
bytes_data = output_bytes.getvalue() | ||
|
||
return base64.b64encode(bytes_data) | ||
|
||
|
||
def api_middleware(app: FastAPI): | ||
rich_available = False | ||
try: | ||
if os.environ.get('WEBUI_RICH_EXCEPTIONS', None) is not None: | ||
import anyio # importing just so it can be placed on silent list | ||
import starlette # importing just so it can be placed on silent list | ||
from rich.console import Console | ||
console = Console() | ||
rich_available = True | ||
except Exception: | ||
pass | ||
|
||
@app.middleware("http") | ||
async def log_and_time(req: Request, call_next): | ||
ts = time.time() | ||
res: Response = await call_next(req) | ||
duration = str(round(time.time() - ts, 4)) | ||
res.headers["X-Process-Time"] = duration | ||
endpoint = req.scope.get('path', 'err') | ||
if shared.cmd_opts.api_log and endpoint.startswith('/sdapi'): | ||
print('API {t} {code} {prot}/{ver} {method} {endpoint} {cli} {duration}'.format( | ||
t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f"), | ||
code=res.status_code, | ||
ver=req.scope.get('http_version', '0.0'), | ||
cli=req.scope.get('client', ('0:0.0.0', 0))[0], | ||
prot=req.scope.get('scheme', 'err'), | ||
method=req.scope.get('method', 'err'), | ||
endpoint=endpoint, | ||
duration=duration, | ||
)) | ||
return res | ||
|
||
def handle_exception(request: Request, e: Exception): | ||
err = { | ||
"error": type(e).__name__, | ||
"detail": vars(e).get('detail', ''), | ||
"body": vars(e).get('body', ''), | ||
"errors": str(e), | ||
} | ||
if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions | ||
message = f"API error: {request.method}: {request.url} {err}" | ||
if rich_available: | ||
print(message) | ||
console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) | ||
else: | ||
errors.report(message, exc_info=True) | ||
return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) | ||
|
||
@app.middleware("http") | ||
async def exception_handling(request: Request, call_next): | ||
try: | ||
return await call_next(request) | ||
except Exception as e: | ||
return handle_exception(request, e) | ||
|
||
@app.exception_handler(Exception) | ||
async def fastapi_exception_handler(request: Request, e: Exception): | ||
return handle_exception(request, e) | ||
|
||
@app.exception_handler(HTTPException) | ||
async def http_exception_handler(request: Request, e: HTTPException): | ||
return handle_exception(request, e) | ||
|
||
|
||
class ApiCompat: | ||
def __init__(self, queue_lock: Lock): | ||
|
||
self.router = APIRouter() | ||
self.app = FastAPI() | ||
self.queue_lock = queue_lock | ||
api_middleware(self.app) | ||
self.add_api_route("/sdapi/v1/txt2img", shark_sd_api, methods=["post"]) | ||
self.add_api_route("/sdapi/v1/img2img", shark_sd_api, methods=["post"]) | ||
#self.add_api_route("/sdapi/v1/upscaler", self.upscaler_api, methods=["post"]) | ||
#self.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=models.ExtrasSingleImageResponse) | ||
#self.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=models.ExtrasBatchImagesResponse) | ||
#self.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=models.PNGInfoResponse) | ||
#self.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=models.ProgressResponse) | ||
#self.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/skip", self.skip, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=models.OptionsModel) | ||
#self.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) | ||
#self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) | ||
#self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) | ||
#self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) | ||
#self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) | ||
#self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) | ||
#self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) | ||
#self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) | ||
#self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) | ||
#self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) | ||
#self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) | ||
#self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) | ||
#self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) | ||
#self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) | ||
#self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=models.TrainResponse) | ||
#self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=models.TrainResponse) | ||
#self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=models.MemoryResponse) | ||
#self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"]) | ||
#self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=models.ScriptsList) | ||
#self.add_api_route("/sdapi/v1/script-info", self.get_script_info, methods=["GET"], response_model=List[models.ScriptInfo]) | ||
|
||
|
||
# chat APIs needed for compatibility with multiple extensions using OpenAI API | ||
self.add_api_route( | ||
"/v1/chat/completions", chat_api, methods=["post"] | ||
) | ||
self.add_api_route("/v1/completions", chat_api, methods=["post"]) | ||
self.add_api_route("/chat/completions", chat_api, methods=["post"]) | ||
self.add_api_route("/completions", chat_api, methods=["post"]) | ||
self.add_api_route( | ||
"/v1/engines/codegen/completions", chat_api, methods=["post"] | ||
) | ||
if studio.cmd_opts.api_server_stop: | ||
self.add_api_route("/sdapi/v1/server-kill", self.kill_studio, methods=["POST"]) | ||
self.add_api_route("/sdapi/v1/server-restart", self.restart_studio, methods=["POST"]) | ||
self.add_api_route("/sdapi/v1/server-stop", self.stop_studio, methods=["POST"]) | ||
|
||
self.default_script_arg_txt2img = [] | ||
self.default_script_arg_img2img = [] | ||
|
||
def add_api_route(self, path:str, endpoint, **kwargs): | ||
if studio.cmd_opts.api_auth: | ||
return self.app.add_api_route(path, endpoint, dependencies=[Depends(self.auth)], **kwargs | ||
return self.app.add_api_route(path, endpoint, **kwargs) | ||
|
||
def refresh_checkpoints(self): | ||
with self.queue_lock: | ||
studio_data.refresh_checkpoints() | ||
|
||
def refresh_vae(self): | ||
with self.queue_lock: | ||
studio_data.refresh_vae_list() | ||
|
||
def unloadapi(self): | ||
unload_model_weights() | ||
|
||
return {} | ||
|
||
def reloadapi(self): | ||
reload_model_weights() | ||
|
||
return {} | ||
|
||
def skip(self): | ||
studio.state.skip() | ||
|
||
def launch(self, server_name, port, root_path): | ||
self.app.include_router(self.router) | ||
uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=studio.cmd_opts.timeout_keep_alive, root_path=root_path) | ||
|
||
def kill_studio(self): | ||
restart.stop_program() | ||
|
||
def restart_studio(self): | ||
if restart.is_restartable(): | ||
restart.restart_program() | ||
return Response(status_code=501) | ||
|
||
def preprocess(self, args: dict): | ||
try: | ||
studio.state.begin(job="preprocess") | ||
preprocess(**args) | ||
studio.state.end() | ||
return models.PreprocessResponse(info='preprocess complete') | ||
except: | ||
studio.state.end() | ||
|
||
def stop_studio(request): | ||
studio.state.server_command = "stop" | ||
return Response("Stopping.") |
Oops, something went wrong.