diff --git a/README.md b/README.md index 4926e8e091..f680da2f7e 100644 --- a/README.md +++ b/README.md @@ -325,6 +325,11 @@ result = shark_module.forward((arg0, arg1)) ``` +## Examples Using the REST API + +* [Setting up SHARK for use with Blender](./docs/shark_sd_blender.md) +* [Setting up SHARK for use with Koboldcpp](./docs/shark_sd_koboldcpp.md) + ## Supported and Validated Models SHARK is maintained to support the latest innovations in ML Models: diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py index e60752a96f..2eb68d6338 100644 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ b/apps/stable_diffusion/src/utils/stable_args.py @@ -253,28 +253,30 @@ def is_valid_file(arg): "--left", default=False, action=argparse.BooleanOptionalAction, - help="If expend left for outpainting.", + help="If extend left for outpainting.", ) p.add_argument( "--right", default=False, action=argparse.BooleanOptionalAction, - help="If expend right for outpainting.", + help="If extend right for outpainting.", ) p.add_argument( + "--up", "--top", default=False, action=argparse.BooleanOptionalAction, - help="If expend top for outpainting.", + help="If extend top for outpainting.", ) p.add_argument( + "--down", "--bottom", default=False, action=argparse.BooleanOptionalAction, - help="If expend bottom for outpainting.", + help="If extend bottom for outpainting.", ) p.add_argument( @@ -641,6 +643,18 @@ def is_valid_file(arg): help="Flag for enabling rest API.", ) +p.add_argument( + "--api_accept_origin", + action="append", + type=str, + help="An origin to be accepted by the REST api for Cross Origin" + "Resource Sharing (CORS). Use multiple times for multiple origins, " + 'or use --api_accept_origin="*" to accept all origins. If no origins ' + "are set no CORS headers will be returned by the api. 
Use, for " + "instance, if you need to access the REST api from Javascript running " + "in a web browser.", +) + p.add_argument( "--debug", default=False, diff --git a/apps/stable_diffusion/web/api/__init__.py b/apps/stable_diffusion/web/api/__init__.py new file mode 100644 index 0000000000..892d1976d7 --- /dev/null +++ b/apps/stable_diffusion/web/api/__init__.py @@ -0,0 +1 @@ +from apps.stable_diffusion.web.api.sdapi_v1 import sdapi diff --git a/apps/stable_diffusion/web/api/sdapi_v1.py b/apps/stable_diffusion/web/api/sdapi_v1.py new file mode 100644 index 0000000000..3eebd5c113 --- /dev/null +++ b/apps/stable_diffusion/web/api/sdapi_v1.py @@ -0,0 +1,579 @@ +import os + +from collections import defaultdict +from enum import Enum +from fastapi import FastAPI +from pydantic import BaseModel, Field, conlist, model_validator + +from apps.stable_diffusion.web.api.utils import ( + frozen_args, + sampler_aliases, + encode_pil_to_base64, + decode_base64_to_image, + get_model_from_request, + get_scheduler_from_request, + get_lora_params, + get_device, + GenerationInputData, + GenerationResponseData, +) + +from apps.stable_diffusion.web.ui.utils import ( + get_custom_model_files, + get_custom_model_pathfile, + predefined_models, + predefined_paint_models, + predefined_upscaler_models, + scheduler_list, +) +from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_inf +from apps.stable_diffusion.web.ui.img2img_ui import img2img_inf +from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_inf +from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_inf +from apps.stable_diffusion.web.ui.upscaler_ui import upscaler_inf + +sdapi = FastAPI() + + +# Rest API: /sdapi/v1/sd-models (lists available models) +class AppParam(str, Enum): + txt2img = "txt2img" + img2img = "img2img" + inpaint = "inpaint" + outpaint = "outpaint" + upscaler = "upscaler" + + +@sdapi.get( + "/v1/sd-models", + summary="lists available models", + description=( + "This is all the models that this server currently knows about.\n " + "Models listed may still have a compilation and build pending that " + "will be triggered the first time they are used." + ), +) +def sd_models_api(app: AppParam = frozen_args.app): + match app: + case "inpaint" | "outpaint": + checkpoint_type = "inpainting" + predefined = predefined_paint_models + case "upscaler": + checkpoint_type = "upscaler" + predefined = predefined_upscaler_models + case _: + checkpoint_type = "" + predefined = predefined_models + + return [ + { + "title": model_file, + "model_name": model_file, + "hash": None, + "sha256": None, + "filename": get_custom_model_pathfile(model_file), + "config": None, + } + for model_file in get_custom_model_files( + custom_checkpoint_type=checkpoint_type + ) + ] + [ + { + "title": model, + "model_name": model, + "hash": None, + "sha256": None, + "filename": None, + "config": None, + } + for model in predefined + ] + + +# Rest API: /sdapi/v1/samplers (lists schedulers) +@sdapi.get( + "/v1/samplers", + summary="lists available schedulers/samplers", + description=( + "These are all the Schedulers defined and available. Not " + "every scheduler is compatible with all apis. Aliases are " + "equivalent samplers in A1111 if they are known." 
+ ), +) +def sd_samplers_api(): + reverse_sampler_aliases = defaultdict(list) + for key, value in sampler_aliases.items(): + reverse_sampler_aliases[value].append(key) + + return ( + { + "name": scheduler, + "aliases": reverse_sampler_aliases.get(scheduler, []), + "options": {}, + } + for scheduler in scheduler_list + ) + + +# Rest API: /sdapi/v1/options (lists application level options) +@sdapi.get( + "/v1/options", + summary="lists current settings of application level options", + description=( + "A subset of the command line arguments set at startup renamed " + "to correspond to the A1111 naming. Only a small subset of A1111 " + "options are returned." + ), +) +def options_api(): + # This is mostly just enough to support what Koboldcpp wants, with a + # few other things that seemed obvious + return { + "samples_save": True, + "samples_format": frozen_args.output_img_format, + "sd_model_checkpoint": os.path.basename(frozen_args.ckpt_loc) + if frozen_args.ckpt_loc + else frozen_args.hf_model_id, + "sd_lora": frozen_args.use_lora, + "sd_vae": frozen_args.custom_vae or "Automatic", + "enable_pnginfo": frozen_args.write_metadata_to_png, + } + + +# Rest API: /sdapi/v1/cmd-flags (lists command line argument settings) +@sdapi.get( + "/v1/cmd-flags", + summary="lists the command line arguments value that were set on startup.", +) +def cmd_flags_api(): + return vars(frozen_args) + + +# Rest API: /sdapi/v1/txt2img (Text to image) +class ModelOverrideSettings(BaseModel): + sd_model_checkpoint: str = get_model_from_request( + fallback_model="stabilityai/stable-diffusion-2-1-base" + ) + + +class Txt2ImgInputData(GenerationInputData): + enable_hr: bool = frozen_args.use_hiresfix + hr_resize_y: int = Field( + default=frozen_args.hiresfix_height, ge=128, le=768, multiple_of=8 + ) + hr_resize_x: int = Field( + default=frozen_args.hiresfix_width, ge=128, le=768, multiple_of=8 + ) + override_settings: ModelOverrideSettings = None + + +@sdapi.post( + "/v1/txt2img", + summary="Does text to image generation", + response_model=GenerationResponseData, +) +def txt2img_api(InputData: Txt2ImgInputData): + model_id = get_model_from_request( + InputData, + fallback_model="stabilityai/stable-diffusion-2-1-base", + ) + scheduler = get_scheduler_from_request( + InputData, "txt2img_hires" if InputData.enable_hr else "txt2img" + ) + (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) + + print( + f"Prompt: {InputData.prompt}, " + f"Negative Prompt: {InputData.negative_prompt}, " + f"Seed: {InputData.seed}," + f"Model: {model_id}, " + f"Scheduler: {scheduler}. 
" + ) + + res = txt2img_inf( + InputData.prompt, + InputData.negative_prompt, + InputData.height, + InputData.width, + InputData.steps, + InputData.cfg_scale, + InputData.seed, + batch_count=InputData.n_iter, + batch_size=1, + scheduler=scheduler, + model_id=model_id, + custom_vae=frozen_args.custom_vae or "None", + precision="fp16", + device=get_device(frozen_args.device), + max_length=frozen_args.max_length, + save_metadata_to_json=frozen_args.save_metadata_to_json, + save_metadata_to_png=frozen_args.write_metadata_to_png, + lora_weights=lora_weights, + lora_hf_id=lora_hf_id, + ondemand=frozen_args.ondemand, + repeatable_seeds=False, + use_hiresfix=InputData.enable_hr, + hiresfix_height=InputData.hr_resize_y, + hiresfix_width=InputData.hr_resize_x, + hiresfix_strength=frozen_args.hiresfix_strength, + resample_type=frozen_args.resample_type, + ) + + # Since we're not streaming we just want the last generator result + for items_so_far in res: + items = items_so_far + + return { + "images": encode_pil_to_base64(items[0]), + "parameters": {}, + "info": items[1], + } + + +# Rest API: /sdapi/v1/img2img (Image to image) +class StencilParam(str, Enum): + canny = "canny" + openpose = "openpose" + scribble = "scribble" + zoedepth = "zoedepth" + + +class Img2ImgInputData(GenerationInputData): + init_images: conlist(str, min_length=1, max_length=2) + denoising_strength: float = frozen_args.strength + use_stencil: StencilParam = frozen_args.use_stencil + override_settings: ModelOverrideSettings = None + + @model_validator(mode="after") + def check_image_supplied_for_scribble_stencil(self) -> "Img2ImgInputData": + if ( + self.use_stencil == StencilParam.scribble + and len(self.init_images) < 2 + ): + raise ValueError( + "a second image must be supplied for the controlnet:scribble stencil" + ) + + return self + + +@sdapi.post( + "/v1/img2img", + summary="Does image to image generation", + response_model=GenerationResponseData, +) +def img2img_api( + InputData: Img2ImgInputData, +): + model_id = get_model_from_request( + InputData, + fallback_model="stabilityai/stable-diffusion-2-1-base", + ) + scheduler = get_scheduler_from_request(InputData, "img2img") + (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) + + init_image = decode_base64_to_image(InputData.init_images[0]) + mask_image = ( + decode_base64_to_image(InputData.init_images[1]) + if len(InputData.init_images) > 1 + else None + ) + + print( + f"Prompt: {InputData.prompt}, " + f"Negative Prompt: {InputData.negative_prompt}, " + f"Seed: {InputData.seed}, " + f"Model: {model_id}, " + f"Scheduler: {scheduler}." 
+ ) + + res = img2img_inf( + InputData.prompt, + InputData.negative_prompt, + {"image": init_image, "mask": mask_image}, + InputData.height, + InputData.width, + InputData.steps, + InputData.denoising_strength, + InputData.cfg_scale, + InputData.seed, + batch_count=InputData.n_iter, + batch_size=1, + scheduler=scheduler, + model_id=model_id, + custom_vae=frozen_args.custom_vae or "None", + precision="fp16", + device=get_device(frozen_args.device), + max_length=frozen_args.max_length, + use_stencil=InputData.use_stencil, + save_metadata_to_json=frozen_args.save_metadata_to_json, + save_metadata_to_png=frozen_args.write_metadata_to_png, + lora_weights=lora_weights, + lora_hf_id=lora_hf_id, + ondemand=frozen_args.ondemand, + repeatable_seeds=False, + resample_type=frozen_args.resample_type, + ) + + # Since we're not streaming we just want the last generator result + for items_so_far in res: + items = items_so_far + + return { + "images": encode_pil_to_base64(items[0]), + "parameters": {}, + "info": items[1], + } + + +# Rest API: /sdapi/v1/inpaint (Inpainting) +class PaintModelOverideSettings(BaseModel): + sd_model_checkpoint: str = get_model_from_request( + checkpoint_type="inpainting", + fallback_model="stabilityai/stable-diffusion-2-inpainting", + ) + + +class InpaintInputData(GenerationInputData): + image: str = Field(description="Base64 encoded input image") + mask: str = Field(description="Base64 encoded mask image") + is_full_res: bool = False # Is this setting backwards in the UI? + full_res_padding: int = Field(default=32, ge=0, le=256, multiple_of=4) + denoising_strength: float = frozen_args.strength + use_stencil: StencilParam = frozen_args.use_stencil + override_settings: PaintModelOverideSettings = None + + +@sdapi.post( + "/v1/inpaint", + summary="Does inpainting generation on an image", + response_model=GenerationResponseData, +) +def inpaint_api( + InputData: InpaintInputData, +): + model_id = get_model_from_request( + InputData, + checkpoint_type="inpainting", + fallback_model="stabilityai/stable-diffusion-2-inpainting", + ) + scheduler = get_scheduler_from_request(InputData, "inpaint") + (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) + + init_image = decode_base64_to_image(InputData.image) + mask = decode_base64_to_image(InputData.mask) + + print( + f"Prompt: {InputData.prompt}, " + f'Negative Prompt: {InputData.negative_prompt}", ' + f'Seed: {InputData.seed}", ' + f"Model: {model_id}, " + f"Scheduler: {scheduler}." 
+ ) + + res = inpaint_inf( + InputData.prompt, + InputData.negative_prompt, + {"image": init_image, "mask": mask}, + InputData.height, + InputData.width, + InputData.is_full_res, + InputData.full_res_padding, + InputData.steps, + InputData.cfg_scale, + InputData.seed, + batch_count=InputData.n_iter, + batch_size=1, + scheduler=scheduler, + model_id=model_id, + custom_vae=frozen_args.custom_vae or "None", + precision="fp16", + device=get_device(frozen_args.device), + max_length=frozen_args.max_length, + save_metadata_to_json=frozen_args.save_metadata_to_json, + save_metadata_to_png=frozen_args.write_metadata_to_png, + lora_weights=lora_weights, + lora_hf_id=lora_hf_id, + ondemand=frozen_args.ondemand, + repeatable_seeds=False, + ) + + # Since we're not streaming we just want the last generator result + for items_so_far in res: + items = items_so_far + + return { + "images": encode_pil_to_base64(items[0]), + "parameters": {}, + "info": items[1], + } + + +# Rest API: /sdapi/v1/outpaint (Outpainting) +class DirectionParam(str, Enum): + left = "left" + right = "right" + up = "up" + down = "down" + + +class OutpaintInputData(GenerationInputData): + init_images: list[str] + pixels: int = Field( + default=frozen_args.pixels, ge=8, le=256, multiple_of=8 + ) + mask_blur: int = Field(default=frozen_args.mask_blur, ge=0, le=64) + directions: set[DirectionParam] = [ + direction + for direction in ["left", "right", "up", "down"] + if vars(frozen_args)[direction] + ] + noise_q: float = frozen_args.noise_q + color_variation: float = frozen_args.color_variation + override_settings: PaintModelOverideSettings = None + + +@sdapi.post( + "/v1/outpaint", + summary="Does outpainting generation on an image", + response_model=GenerationResponseData, +) +def outpaint_api( + InputData: OutpaintInputData, +): + model_id = get_model_from_request( + InputData, + checkpoint_type="inpainting", + fallback_model="stabilityai/stable-diffusion-2-inpainting", + ) + scheduler = get_scheduler_from_request(InputData, "outpaint") + (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) + + init_image = decode_base64_to_image(InputData.init_images[0]) + + print( + f"Prompt: {InputData.prompt}, " + f"Negative Prompt: {InputData.negative_prompt}, " + f"Seed: {InputData.seed}, " + f"Model: {model_id}, " + f"Scheduler: {scheduler}." 
+ ) + + res = outpaint_inf( + InputData.prompt, + InputData.negative_prompt, + init_image, + InputData.pixels, + InputData.mask_blur, + InputData.directions, + InputData.noise_q, + InputData.color_variation, + InputData.height, + InputData.width, + InputData.steps, + InputData.cfg_scale, + InputData.seed, + batch_count=InputData.n_iter, + batch_size=1, + scheduler=scheduler, + model_id=model_id, + custom_vae=frozen_args.custom_vae or "None", + precision="fp16", + device=get_device(frozen_args.device), + max_length=frozen_args.max_length, + save_metadata_to_json=frozen_args.save_metadata_to_json, + save_metadata_to_png=frozen_args.write_metadata_to_png, + lora_weights=lora_weights, + lora_hf_id=lora_hf_id, + ondemand=frozen_args.ondemand, + repeatable_seeds=False, + ) + + # Since we're not streaming we just want the last generator result + for items_so_far in res: + items = items_so_far + + return { + "images": encode_pil_to_base64(items[0]), + "parameters": {}, + "info": items[1], + } + + +# Rest API: /sdapi/v1/upscaler (Upscaling) +class UpscalerModelOverideSettings(BaseModel): + sd_model_checkpoint: str = get_model_from_request( + checkpoint_type="upscaler", + fallback_model="stabilityai/stable-diffusion-x4-upscaler", + ) + + +class UpscalerInputData(GenerationInputData): + init_images: list[str] = Field( + description="Base64 encoded image to upscale" + ) + noise_level: int = frozen_args.noise_level + override_settings: UpscalerModelOverideSettings = None + + +@sdapi.post( + "/v1/upscaler", + summary="Does image upscaling", + response_model=GenerationResponseData, +) +def upscaler_api( + InputData: UpscalerInputData, +): + model_id = get_model_from_request( + InputData, + checkpoint_type="upscaler", + fallback_model="stabilityai/stable-diffusion-x4-upscaler", + ) + scheduler = get_scheduler_from_request(InputData, "upscaler") + (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) + + init_image = decode_base64_to_image(InputData.init_images[0]) + + print( + f"Prompt: {InputData.prompt}, " + f"Negative Prompt: {InputData.negative_prompt}, " + f"Seed: {InputData.seed}, " + f"Model: {model_id}, " + f"Scheduler: {scheduler}." 
+ ) + + res = upscaler_inf( + InputData.prompt, + InputData.negative_prompt, + init_image, + InputData.height, + InputData.width, + InputData.steps, + InputData.noise_level, + InputData.cfg_scale, + InputData.seed, + batch_count=InputData.n_iter, + batch_size=1, + scheduler=scheduler, + model_id=model_id, + custom_vae=frozen_args.custom_vae or "None", + precision="fp16", + device=get_device(frozen_args.device), + max_length=frozen_args.max_length, + save_metadata_to_json=frozen_args.save_metadata_to_json, + save_metadata_to_png=frozen_args.write_metadata_to_png, + lora_weights=lora_weights, + lora_hf_id=lora_hf_id, + ondemand=frozen_args.ondemand, + repeatable_seeds=False, + ) + + # Since we're not streaming we just want the last generator result + for items_so_far in res: + items = items_so_far + + return { + "images": encode_pil_to_base64(items[0]), + "parameters": {}, + "info": items[1], + } diff --git a/apps/stable_diffusion/web/api/utils.py b/apps/stable_diffusion/web/api/utils.py new file mode 100644 index 0000000000..eca422f9dc --- /dev/null +++ b/apps/stable_diffusion/web/api/utils.py @@ -0,0 +1,211 @@ +import base64 +import pickle + +from argparse import Namespace +from fastapi.exceptions import HTTPException +from io import BytesIO +from PIL import Image +from pydantic import BaseModel, Field + +from apps.stable_diffusion.src import args +from apps.stable_diffusion.web.ui.utils import ( + available_devices, + get_custom_model_files, + predefined_models, + predefined_paint_models, + predefined_upscaler_models, + scheduler_list, + scheduler_list_cpu_only, +) + + +# Probably overly cautious, but try to ensure we only use the starting +# args in each api call, as the code does `args. = ` +# in lots of places and in testing, it seemed to me, these changes leaked +# into subsequent api calls. 
+ +# Roundtripping through pickle for deepcopy, there is probably a better way +frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args))))) + +# an attempt to map some of the A1111 sampler names to scheduler names +# https://github.com/huggingface/diffusers/issues/4167 is where the +# (not so obvious) ones come from +sampler_aliases = { + # a1111/onnx (these point to diffusers classes in A1111) + "pndm": "PNDM", + "heun": "HeunDiscrete", + "ddim": "DDIM", + "ddpm": "DDPM", + "euler": "EulerDiscrete", + "euler-ancestral": "EulerAncestralDiscrete", + "dpm": "DPMSolverMultistep", + # a1111/k_diffusion (the obvious ones) + "Euler a": "EulerAncestralDiscrete", + "Euler": "EulerDiscrete", + "LMS": "LMSDiscrete", + "Heun": "HeunDiscrete", + # a1111/k_diffusion (not so obvious) + "DPM++ 2M": "DPMSolverMultistep", + "DPM++ 2M Karras": "DPMSolverMultistepKarras", + "DPM++ 2M SDE": "DPMSolverMultistep++", + "DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++", + "DPM2": "KDPM2Discrete", + "DPM2 a": "KDPM2AncestralDiscrete", +} + +allowed_schedulers = { + "txt2img": { + "schedulers": scheduler_list, + "fallback": "SharkEulerDiscrete", + }, + "txt2img_hires": { + "schedulers": scheduler_list_cpu_only, + "fallback": "DEISMultistep", + }, + "img2img": { + "schedulers": scheduler_list_cpu_only, + "fallback": "EulerDiscrete", + }, + "inpaint": { + "schedulers": scheduler_list_cpu_only, + "fallback": "DDIM", + }, + "outpaint": { + "schedulers": scheduler_list_cpu_only, + "fallback": "DDIM", + }, + "upscaler": { + "schedulers": scheduler_list_cpu_only, + "fallback": "DDIM", + }, +} + +# base pydantic model for sd generation apis + + +class GenerationInputData(BaseModel): + prompt: str = "" + negative_prompt: str = "" + hf_model_id: str | None = None + height: int = Field( + default=frozen_args.height, ge=128, le=768, multiple_of=8 + ) + width: int = Field( + default=frozen_args.width, ge=128, le=768, multiple_of=8 + ) + sampler_name: str = frozen_args.scheduler + cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1) + steps: int = Field(default=frozen_args.steps, ge=1, le=100) + seed: int = frozen_args.seed + n_iter: int = Field(default=frozen_args.batch_count) + + +class GenerationResponseData(BaseModel): + images: list[str] = Field(description="Generated images, Base64 encoded") + properties: dict = {} + info: str + + +# image encoding/decoding + + +def encode_pil_to_base64(images: list[Image.Image]): + encoded_imgs = [] + for image in images: + with BytesIO() as output_bytes: + if frozen_args.output_img_format.lower() == "png": + image.save(output_bytes, format="PNG") + + elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"): + image.save(output_bytes, format="JPEG") + else: + raise HTTPException( + status_code=500, detail="Invalid image format" + ) + bytes_data = output_bytes.getvalue() + encoded_imgs.append(base64.b64encode(bytes_data)) + return encoded_imgs + + +def decode_base64_to_image(encoding: str): + if encoding.startswith("data:image/"): + encoding = encoding.split(";", 1)[1].split(",", 1)[1] + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) + return image + except Exception as err: + print(err) + raise HTTPException(status_code=400, detail="Invalid encoded image") + + +# get valid sd models/vaes/schedulers etc. 
+ + +def get_predefined_models(custom_checkpoint_type: str): + match custom_checkpoint_type: + case "inpainting": + return predefined_paint_models + case "upscaler": + return predefined_upscaler_models + case _: + return predefined_models + + +def get_model_from_request( + request_data=None, + checkpoint_type: str = "", + fallback_model: str = "", +): + model = None + if request_data: + if request_data.hf_model_id: + model = request_data.hf_model_id + elif request_data.override_settings: + model = request_data.override_settings.sd_model_checkpoint + + # if the request didn't specify a model try the command line args + result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id + + # make sure whatever we have is a valid model for the checkpoint type + if result in get_custom_model_files( + custom_checkpoint_type=checkpoint_type + ) + get_predefined_models(checkpoint_type): + return result + # if not return what was specified as the fallback + else: + return fallback_model + + +def get_scheduler_from_request( + request_data: GenerationInputData, operation: str +): + allowed = allowed_schedulers[operation] + + requested = request_data.sampler_name + requested = sampler_aliases.get(requested, requested) + + return ( + requested + if requested in allowed["schedulers"] + else allowed["fallback"] + ) + + +def get_lora_params(use_lora: str): + # TODO: since the inference functions in the webui, which we are + # still calling into for the api, jam these back together again before + # handing them off to the pipeline, we should remove this nonsense + # and unify their selection in the UI and command line args proper + if use_lora in get_custom_model_files("lora"): + return (use_lora, "") + + return ("None", use_lora) + + +def get_device(device_str: str): + # first substring match in the list available devices, with first + # device when none are matched + return next( + (device for device in available_devices if device_str in device), + available_devices[0], + ) diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index 9923fd72ad..930ea31b17 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -48,26 +48,19 @@ def launch_app(address): freeze_support() if args.api or "api" in args.ui.split(","): from apps.stable_diffusion.web.ui import ( - txt2img_api, - img2img_api, - upscaler_api, - inpaint_api, - outpaint_api, llm_chat_api, ) + from apps.stable_diffusion.web.api import sdapi from fastapi import FastAPI, APIRouter + from fastapi.middleware.cors import CORSMiddleware import uvicorn # init global sd pipeline and config global_obj._init() app = FastAPI() - app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"]) - app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"]) - app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"]) - app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"]) - app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"]) + app.mount("/sdapi/", sdapi) # chat APIs needed for compatibility with multiple extensions using OpenAI API app.add_api_route( @@ -80,6 +73,21 @@ def launch_app(address): "/v1/engines/codegen/completions", llm_chat_api, methods=["post"] ) app.include_router(APIRouter()) + + # deal with CORS requests if CORS accept origins are set + if args.api_accept_origin: + print( + f"API Configured for CORS. 
Accepting origins: { args.api_accept_origin }" + ) + app.add_middleware( + CORSMiddleware, + allow_origins=args.api_accept_origin, + allow_methods=["GET", "POST"], + allow_headers=["*"], + ) + else: + print("API not configured for CORS") + uvicorn.run(app, host="0.0.0.0", port=args.server_port) sys.exit(0) diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py index 937c48f549..10cef374a1 100644 --- a/apps/stable_diffusion/web/ui/__init__.py +++ b/apps/stable_diffusion/web/ui/__init__.py @@ -1,6 +1,5 @@ from apps.stable_diffusion.web.ui.txt2img_ui import ( txt2img_inf, - txt2img_api, txt2img_web, txt2img_custom_model, txt2img_gallery, @@ -13,7 +12,6 @@ ) from apps.stable_diffusion.web.ui.img2img_ui import ( img2img_inf, - img2img_api, img2img_web, img2img_custom_model, img2img_gallery, @@ -25,7 +23,6 @@ ) from apps.stable_diffusion.web.ui.inpaint_ui import ( inpaint_inf, - inpaint_api, inpaint_web, inpaint_custom_model, inpaint_gallery, @@ -37,7 +34,6 @@ ) from apps.stable_diffusion.web.ui.outpaint_ui import ( outpaint_inf, - outpaint_api, outpaint_web, outpaint_custom_model, outpaint_gallery, @@ -49,7 +45,6 @@ ) from apps.stable_diffusion.web.ui.upscaler_ui import ( upscaler_inf, - upscaler_api, upscaler_web, upscaler_custom_model, upscaler_gallery, diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index c358955922..02ec315247 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -5,9 +5,6 @@ import PIL from math import ceil from PIL import Image -import base64 -from io import BytesIO -from fastapi.exceptions import HTTPException from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -277,87 +274,6 @@ def img2img_inf( return generated_imgs, text_output, "" -def decode_base64_to_image(encoding): - if encoding.startswith("data:image/"): - encoding = encoding.split(";", 1)[1].split(",", 1)[1] - try: - image = Image.open(BytesIO(base64.b64decode(encoding))) - return image - except Exception as err: - print(err) - raise HTTPException(status_code=500, detail="Invalid encoded image") - - -def encode_pil_to_base64(images): - encoded_imgs = [] - for image in images: - with BytesIO() as output_bytes: - if args.output_img_format.lower() == "png": - image.save(output_bytes, format="PNG") - - elif args.output_img_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG") - else: - raise HTTPException( - status_code=500, detail="Invalid image format" - ) - bytes_data = output_bytes.getvalue() - encoded_imgs.append(base64.b64encode(bytes_data)) - return encoded_imgs - - -# Img2Img Rest API. -def img2img_api( - InputData: dict, -): - print( - f'Prompt: {InputData["prompt"]}, ' - f'Negative Prompt: {InputData["negative_prompt"]}, ' - f'Seed: {InputData["seed"]}.' 
- ) - init_image = decode_base64_to_image(InputData["init_images"][0]) - res = img2img_inf( - InputData["prompt"], - InputData["negative_prompt"], - init_image, - InputData["height"], - InputData["width"], - InputData["steps"], - InputData["denoising_strength"], - InputData["cfg_scale"], - InputData["seed"], - batch_count=1, - batch_size=1, - scheduler="EulerDiscrete", - model_id=InputData["hf_model_id"] - if "hf_model_id" in InputData.keys() - else "stabilityai/stable-diffusion-2-1-base", - custom_vae="None", - precision="fp16", - device=available_devices[0], - max_length=64, - use_stencil=InputData["use_stencil"] - if "use_stencil" in InputData.keys() - else "None", - save_metadata_to_json=False, - save_metadata_to_png=False, - lora_weights="None", - lora_hf_id="", - ondemand=False, - repeatable_seeds=False, - resample_type="Lanczos", - ) - - # Converts generator type to subscriptable - res = next(res) - - return { - "images": encode_pil_to_base64(res[0]), - "parameters": {}, - "info": res[1], - } - - with gr.Blocks(title="Image-to-Image") as img2img_web: with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 72187d80f8..38e28cb354 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -4,9 +4,6 @@ import sys import gradio as gr from PIL import Image -import base64 -from io import BytesIO -from fastapi.exceptions import HTTPException from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -223,85 +220,6 @@ def inpaint_inf( return generated_imgs, text_output -def decode_base64_to_image(encoding): - if encoding.startswith("data:image/"): - encoding = encoding.split(";", 1)[1].split(",", 1)[1] - try: - image = Image.open(BytesIO(base64.b64decode(encoding))) - return image - except Exception as err: - print(err) - raise HTTPException(status_code=500, detail="Invalid encoded image") - - -def encode_pil_to_base64(images): - encoded_imgs = [] - for image in images: - with BytesIO() as output_bytes: - if args.output_img_format.lower() == "png": - image.save(output_bytes, format="PNG") - - elif args.output_img_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG") - else: - raise HTTPException( - status_code=500, detail="Invalid image format" - ) - bytes_data = output_bytes.getvalue() - encoded_imgs.append(base64.b64encode(bytes_data)) - return encoded_imgs - - -# Inpaint Rest API. -def inpaint_api( - InputData: dict, -): - print( - f'Prompt: {InputData["prompt"]}, ' - f'Negative Prompt: {InputData["negative_prompt"]}, ' - f'Seed: {InputData["seed"]}.' 
- ) - init_image = decode_base64_to_image(InputData["image"]) - mask = decode_base64_to_image(InputData["mask"]) - res = inpaint_inf( - InputData["prompt"], - InputData["negative_prompt"], - {"image": init_image, "mask": mask}, - InputData["height"], - InputData["width"], - InputData["is_full_res"], - InputData["full_res_padding"], - InputData["steps"], - InputData["cfg_scale"], - InputData["seed"], - batch_count=1, - batch_size=1, - scheduler="EulerDiscrete", - model_id=InputData["hf_model_id"] - if "hf_model_id" in InputData.keys() - else "stabilityai/stable-diffusion-2-inpainting", - custom_vae="None", - precision="fp16", - device=available_devices[0], - max_length=64, - save_metadata_to_json=False, - save_metadata_to_png=False, - lora_weights="None", - lora_hf_id="", - ondemand=False, - repeatable_seeds=False, - ) - - # Converts generator type to subscriptable - res = next(res) - - return { - "images": encode_pil_to_base64(res[0]), - "parameters": {}, - "info": res[1], - } - - with gr.Blocks(title="Inpainting") as inpaint_web: with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py index 7952c47773..ea5272c9b5 100644 --- a/apps/stable_diffusion/web/ui/outpaint_ui.py +++ b/apps/stable_diffusion/web/ui/outpaint_ui.py @@ -228,87 +228,6 @@ def outpaint_inf( return generated_imgs, text_output, "" -def decode_base64_to_image(encoding): - if encoding.startswith("data:image/"): - encoding = encoding.split(";", 1)[1].split(",", 1)[1] - try: - image = Image.open(BytesIO(base64.b64decode(encoding))) - return image - except Exception as err: - print(err) - raise HTTPException(status_code=500, detail="Invalid encoded image") - - -def encode_pil_to_base64(images): - encoded_imgs = [] - for image in images: - with BytesIO() as output_bytes: - if args.output_img_format.lower() == "png": - image.save(output_bytes, format="PNG") - - elif args.output_img_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG") - else: - raise HTTPException( - status_code=500, detail="Invalid image format" - ) - bytes_data = output_bytes.getvalue() - encoded_imgs.append(base64.b64encode(bytes_data)) - return encoded_imgs - - -# Inpaint Rest API. 
-def outpaint_api( - InputData: dict, -): - print( - f'Prompt: {InputData["prompt"]}, ' - f'Negative Prompt: {InputData["negative_prompt"]}, ' - f'Seed: {InputData["seed"]}' - ) - init_image = decode_base64_to_image(InputData["init_images"][0]) - res = outpaint_inf( - InputData["prompt"], - InputData["negative_prompt"], - init_image, - InputData["pixels"], - InputData["mask_blur"], - InputData["directions"], - InputData["noise_q"], - InputData["color_variation"], - InputData["height"], - InputData["width"], - InputData["steps"], - InputData["cfg_scale"], - InputData["seed"], - batch_count=1, - batch_size=1, - scheduler="EulerDiscrete", - model_id=InputData["hf_model_id"] - if "hf_model_id" in InputData.keys() - else "stabilityai/stable-diffusion-2-inpainting", - custom_vae="None", - precision="fp16", - device=available_devices[0], - max_length=64, - save_metadata_to_json=False, - save_metadata_to_png=False, - lora_weights="None", - lora_hf_id="", - ondemand=False, - repeatable_seeds=False, - ) - - # Convert Generator to Subscriptable - res = next(res) - - return { - "images": encode_pil_to_base64(res[0]), - "parameters": {}, - "info": res[1], - } - - with gr.Blocks(title="Outpainting") as outpaint_web: with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index a7672f6ad3..9c7cba7fed 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -5,9 +5,6 @@ import gradio as gr from PIL import Image from math import ceil -import base64 -from io import BytesIO -from fastapi.exceptions import HTTPException from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -301,74 +298,6 @@ def txt2img_inf( return generated_imgs, text_output, "" -def encode_pil_to_base64(images): - encoded_imgs = [] - for image in images: - with BytesIO() as output_bytes: - if args.output_img_format.lower() == "png": - image.save(output_bytes, format="PNG") - - elif args.output_img_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG") - else: - raise HTTPException( - status_code=500, detail="Invalid image format" - ) - bytes_data = output_bytes.getvalue() - encoded_imgs.append(base64.b64encode(bytes_data)) - return encoded_imgs - - -# Text2Img Rest API. -def txt2img_api( - InputData: dict, -): - print( - f'Prompt: {InputData["prompt"]}, ' - f'Negative Prompt: {InputData["negative_prompt"]}, ' - f'Seed: {InputData["seed"]}.' 
- ) - res = txt2img_inf( - InputData["prompt"], - InputData["negative_prompt"], - InputData["height"], - InputData["width"], - InputData["steps"], - InputData["cfg_scale"], - InputData["seed"], - batch_count=1, - batch_size=1, - scheduler="EulerDiscrete", - model_id=InputData["hf_model_id"] - if "hf_model_id" in InputData.keys() - else "stabilityai/stable-diffusion-2-1-base", - custom_vae="None", - precision="fp16", - device=available_devices[0], - max_length=64, - save_metadata_to_json=False, - save_metadata_to_png=False, - lora_weights="None", - lora_hf_id="", - ondemand=False, - repeatable_seeds=False, - use_hiresfix=False, - hiresfix_height=512, - hiresfix_width=512, - hiresfix_strength=0.6, - resample_type="Nearest Neighbor", - ) - - # Convert Generator to Subscriptable - res = next(res) - - return { - "images": encode_pil_to_base64(res[0]), - "parameters": {}, - "info": res[1], - } - - with gr.Blocks(title="Text-to-Image") as txt2img_web: with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py index 830287e031..6d40861e83 100644 --- a/apps/stable_diffusion/web/ui/upscaler_ui.py +++ b/apps/stable_diffusion/web/ui/upscaler_ui.py @@ -3,9 +3,6 @@ import time import gradio as gr from PIL import Image -import base64 -from io import BytesIO -from fastapi.exceptions import HTTPException from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -247,82 +244,6 @@ def upscaler_inf( yield generated_imgs, text_output, "" -def decode_base64_to_image(encoding): - if encoding.startswith("data:image/"): - encoding = encoding.split(";", 1)[1].split(",", 1)[1] - try: - image = Image.open(BytesIO(base64.b64decode(encoding))) - return image - except Exception as err: - print(err) - raise HTTPException(status_code=500, detail="Invalid encoded image") - - -def encode_pil_to_base64(images): - encoded_imgs = [] - for image in images: - with BytesIO() as output_bytes: - if args.output_img_format.lower() == "png": - image.save(output_bytes, format="PNG") - - elif args.output_img_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG") - else: - raise HTTPException( - status_code=500, detail="Invalid image format" - ) - bytes_data = output_bytes.getvalue() - encoded_imgs.append(base64.b64encode(bytes_data)) - return encoded_imgs - - -# Upscaler Rest API. 
-def upscaler_api(
-    InputData: dict,
-):
-    print(
-        f'Prompt: {InputData["prompt"]}, '
-        f'Negative Prompt: {InputData["negative_prompt"]}, '
-        f'Seed: {InputData["seed"]}'
-    )
-    init_image = decode_base64_to_image(InputData["init_images"][0])
-    res = upscaler_inf(
-        InputData["prompt"],
-        InputData["negative_prompt"],
-        init_image,
-        InputData["height"],
-        InputData["width"],
-        InputData["steps"],
-        InputData["noise_level"],
-        InputData["cfg_scale"],
-        InputData["seed"],
-        batch_count=1,
-        batch_size=1,
-        scheduler="EulerDiscrete",
-        model_id=InputData["hf_model_id"]
-        if "hf_model_id" in InputData.keys()
-        else "stabilityai/stable-diffusion-2-1-base",
-        custom_vae="None",
-        precision="fp16",
-        device=available_devices[0],
-        max_length=64,
-        save_metadata_to_json=False,
-        save_metadata_to_png=False,
-        lora_weights="None",
-        lora_hf_id="",
-        ondemand=False,
-        repeatable_seeds=False,
-    )
-    # Converts generator type to subscriptable
-    res = next(res)
-
-    return {
-        "images": encode_pil_to_base64(res[0]),
-        "parameters": {},
-        "info": res[1],
-    }
-
-
 with gr.Blocks(title="Upscaler") as upscaler_web:
     with gr.Row(elem_id="ui_title"):
         nod_logo = Image.open(nodlogo_loc)
diff --git a/docs/shark_sd_koboldcpp.md b/docs/shark_sd_koboldcpp.md
new file mode 100644
index 0000000000..41ef540b40
--- /dev/null
+++ b/docs/shark_sd_koboldcpp.md
@@ -0,0 +1,140 @@
+# Overview
+
+In [1.47.2](https://github.com/LostRuins/koboldcpp/releases/tag/v1.47.2), [Koboldcpp](https://github.com/LostRuins/koboldcpp) added AUTOMATIC1111 integration for image generation. Since SHARK implements a small subset of the A1111 REST API, you can also use SHARK for this. This document gives a starting point for how to get this working.
+
+## In Action
+
+![preview](https://user-images.githubusercontent.com/121311569/280557602-bb97bad0-fdf5-4922-a2cc-4f327f2760db.jpg)
+
+## Memory considerations
+
+Since both Koboldcpp and SHARK use VRAM on your graphics card(s), running both at the same time on the same card will impose extra limits on the model size you can fully offload to the video card in Koboldcpp. For me, on an RX 7900 XTX on Windows with 24 GiB of VRAM, the limit was about a 13 billion parameter model with Q5_K_M quantisation.
+
+## Performance Considerations
+
+When using SHARK for image generation, especially with Koboldcpp, you need to be aware that it is currently designed to pay a large upfront cost in time compiling and tuning the model you select, in exchange for an optimal per-image generation time. You need to be the judge as to whether this trade-off is going to be worth it for your OS and hardware combination.
+
+It means that the first time you run a particular Stable Diffusion model for a particular combination of image size, LoRA, and VAE, SHARK will spend *many minutes* - even on a beefy machine with a very fast graphics card and lots of memory - building that model combination just so it can save it to disk. It may even have to go away and download the model if it doesn't already have it locally. Once it has built a model combination for your hardware, it shouldn't need to do it again until you upgrade to a newer SHARK version, install different drivers, or change your graphics hardware. It will just upload the files it generated the first time to your graphics card and proceed from there.
+
+This does mean, however, that on a brand new install of SHARK, or with a model you haven't selected before, the first image Koboldcpp requests may look like it is *never* going to finish and that the whole process has broken. Be forewarned, make yourself a cup of coffee, and expect a lot of messages about compilation and tuning from SHARK in the terminal you ran it from.
+
+## Setup SHARK and prerequisites:
+
+ * Make sure you have suitable drivers for your graphics card installed. See the prerequisites section of the [README](https://github.com/nod-ai/SHARK#readme).
+ * Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases), or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux, or Mac install.
+ * Run SHARK from terminal/PowerShell with the `--api` flag. Since Koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `"*"` to accept all origins) and `--server_port=7860` on the command line. (See [below](#connecting-to-shark-on-a-different-address-or-port) if you want to run SHARK on a different port.)
+
+```powershell
+## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
+.\node_ai_shark_studio__.exe --api --api_accept_origin="*" --server_port=7860
+
+## Run from the base directory of a source clone of SHARK on Windows:
+.\setup_venv.ps1
+python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
+
+## Run from the base directory of a source clone of SHARK on Linux:
+./setup_venv.sh
+source shark.venv/bin/activate
+python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
+
+## An example giving improved performance on AMD cards using Vulkan, that runs on the same port as A1111:
+.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
+
+## Since the API respects most applicable SHARK command line arguments for options not specified in,
+## or not yet implemented by, the API, there may be others you want to set, as listed in `--help`:
+.\node_ai_shark_studio_20320901_2525.exe --help
+
+## For instance, the example above, but with a custom VAE specified:
+.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
+
+## An example with multiple specific CORS origins:
+python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
+```
+
+SHARK should start in server mode, and you should see something like this:
+
+![SHARK API startup](https://user-images.githubusercontent.com/121311569/280556294-c3f7fc1a-c8e2-467d-afe6-365638d6823a.png)
+
+* Note: When running in API mode with `--api`, the .exe will not function as a webUI. The address or port shown in the terminal output will only be useful for API requests.
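+
+If you want to confirm the API is reachable before wiring up Koboldcpp, the model list endpoint is a convenient thing to hit. Below is a minimal sketch using Python's `requests` package; it assumes SHARK was started on port `7860` as shown above:
+
+```python
+import requests
+
+# ask SHARK which Stable Diffusion models it currently knows about
+res = requests.get("http://127.0.0.1:7860/sdapi/v1/sd-models", timeout=60)
+res.raise_for_status()
+
+for model in res.json():
+    print(model["model_name"])
+```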
+
+
+## Configure Koboldcpp for local image generation:
+
+* Get the latest [Koboldcpp](https://github.com/LostRuins/koboldcpp/releases) if you don't already have it. If you have a recent AMD card that has ROCm HIP [support for Windows](https://rocmdocs.amd.com/en/latest/release/windows_support.html#windows-supported-gpus) or [support for Linux](https://rocmdocs.amd.com/en/latest/release/gpu_os_support.html#linux-supported-gpus), you'll likely prefer [YellowRosecx's ROCm fork](https://github.com/YellowRoseCx/koboldcpp-rocm).
+* Start Koboldcpp in another terminal/PowerShell and set up your model configuration. Refer to the [Koboldcpp README](https://github.com/YellowRoseCx/koboldcpp-rocm) for more details on how to do this if this is your first time using Koboldcpp.
+* Once the main UI has loaded into your browser, click the settings button, go to the advanced tab, and then choose *Local A1111* from the generate images dropdown:
+
+  ![Settings button location](https://user-images.githubusercontent.com/121311569/280556246-10692d79-e89f-4fdf-87ba-82f3d78ed49d.png)
+
+  ![Advanced Settings with 'Local A1111' location](https://user-images.githubusercontent.com/121311569/280556234-6ebc8ba7-1469-442a-93a7-5626a094ddf1.png)
+
+  *If you get an error here, see the next section [below](#connecting-to-shark-on-a-different-address-or-port).*
+
+* A list of the Stable Diffusion models available to your SHARK instance should now be shown in the box below *generate images*. The default value will usually be `stabilityai/stable-diffusion-2-1-base`. Choose the model you want to use for image generation from the list (but see [performance considerations](#performance-considerations)).
+* You should now be ready to generate images, either by clicking the 'Add Img' button above the text entry box:
+
+  ![Add Image Button](https://user-images.githubusercontent.com/121311569/280556161-846c7883-4a83-4458-a56a-bd9f93ca354c.png)
+
+  ...or by selecting the 'Autogenerate' option in the settings:
+
+  ![Setting the autogenerate images option](https://user-images.githubusercontent.com/121311569/280556230-ae221a46-ba68-499b-a519-c8f290bbbeae.png)
+
+  *I often find that even if I have selected autogenerate, I have to do an 'Add Img' to get things started off.*
+
+* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:
+
+  ![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)
+
+  This will bring up a dialog box where you can enter a short text that will be sent as a prefix to the prompt sent to SHARK:
+
+  ![Entering extra image styles](https://github.com/one-lithe-rune/SHARK/assets/121311569/4aab9794-7a77-46d7-bdda-43df570ad19a)
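+
+Once this is configured, Koboldcpp simply issues an ordinary A1111-style `txt2img` request whenever it wants a picture. For reference, here is a rough sketch of the same kind of request made directly with Python's `requests` package; the exact fields Koboldcpp sends may differ, any field you leave out falls back to SHARK's command line defaults, and the example assumes the default PNG output format:
+
+```python
+import base64
+import requests
+
+payload = {
+    "prompt": "a watercolour painting of a lighthouse",
+    "negative_prompt": "blurry, deformed",
+    "width": 512,
+    "height": 512,
+    "steps": 30,
+    "cfg_scale": 7.5,
+    "sampler_name": "EulerDiscrete",
+}
+
+res = requests.post(
+    "http://127.0.0.1:7860/sdapi/v1/txt2img", json=payload, timeout=1000
+)
+res.raise_for_status()
+
+# the response contains a list of base64 encoded images
+with open("txt2img_result.png", "wb") as f:
+    f.write(base64.b64decode(res.json()["images"][0]))
+```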
+
+
+## Connecting to SHARK on a different address or port
+
+If you didn't set the port with `--server_port=7860` when starting SHARK, or you are running it on a different machine on your network from the one running Koboldcpp (or Koboldcpp's Kobold Lite client frontend), then you very likely got the following error:
+
+ ![Can't find the A1111 endpoint error](https://user-images.githubusercontent.com/121311569/280555857-601f53dc-35e9-4027-9180-baa61d2393ba.png)
+
+As long as SHARK is running correctly, this means you need to set the URL and port to the correct values in Koboldcpp. For instance, to set the port where Koboldcpp looks for an image generator to SHARK's default port of 8080:
+
+* Select the cog icon in the Generate Images section of Advanced settings:
+
+  ![Selecting the endpoint cog](https://user-images.githubusercontent.com/121311569/280555866-4287ecc5-f29f-4c03-8f5a-abeaf31b0442.png)
+
+* Then edit the port number at the end of the URL in the 'A1111 Endpoint Selection' dialog box to read 8080:
+
+  ![Changing the endpoint port](https://user-images.githubusercontent.com/121311569/280556170-f8848b7b-6fc9-4cf7-80eb-5c312f332fd9.png)
+
+* Similarly, when running SHARK on a different machine, change the host part of the endpoint URL to the hostname or IP address where SHARK is running:
+
+  ![Changing the endpoint hostname](https://user-images.githubusercontent.com/121311569/280556167-c6541dea-0f85-417a-b661-fdf4dc40d05f.png)
+
+## Examples
+
+Here's how Koboldcpp shows an image being requested:
+
+ ![An image being generated](https://user-images.githubusercontent.com/121311569/280556210-bb1c9efd-79ac-478e-b726-b25b82ef2186.png)
+
+The generated image in context in story mode:
+
+ ![A generated image](https://user-images.githubusercontent.com/121311569/280556179-4e9f3752-f349-4cba-bc6a-f85f8dc79b10.jpg)
+
+And the same image when clicked on:
+
+ ![A selected image](https://user-images.githubusercontent.com/121311569/280556216-2ca4c0a4-3889-4ef5-8a09-30084fb34081.jpg)
+
+
+## Where to find the images in SHARK
+
+Even though Koboldcpp requests images at a size of 512x512, it resizes them to 256x256, converts them to `.jpeg`, and only shows them at 200x200 in the main text window. It does this so it can save them compactly embedded in your story as a `data:` URI.
+
+However, the images at their original size are saved by SHARK in its `output_dir`, which is usually a folder named for the current date inside the `generated_imgs` folder in the SHARK installation directory.
+
+ +You can browse these, either using the Output Gallery tab from within the SHARK web ui: + + ![SHARK web ui output gallery tab](https://user-images.githubusercontent.com/121311569/280556582-9303ca85-2594-4a8c-97a2-fbd72337980b.jpg) + +...or by browsing to the `output_dir` in your operating system's file manager: + + ![SHARK output directory subfolder in Windows File Explorer](https://user-images.githubusercontent.com/121311569/280556297-66173030-2324-415c-a236-ef3fcd73e6ed.jpg) diff --git a/rest_api_tests/api_test.py b/rest_api_tests/api_test.py index 365dc51dcf..7a4cf042c2 100644 --- a/rest_api_tests/api_test.py +++ b/rest_api_tests/api_test.py @@ -4,7 +4,7 @@ from io import BytesIO -def upscaler_test(): +def upscaler_test(verbose=False): # Define values here prompt = "" negative_prompt = "" @@ -44,10 +44,17 @@ def upscaler_test(): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print(f"response from server was : {res.status_code}") + print( + f"[upscaler] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print( + f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" + ) -def img2img_test(): +def img2img_test(verbose=False): # Define values here prompt = "Paint a rabbit riding on the dog" negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft" @@ -87,7 +94,16 @@ def img2img_test(): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print(f"response from server was : {res.status_code}") + res = requests.post(url=url, json=data, headers=headers, timeout=1000) + + print( + f"[img2img] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print( + f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" + ) # NOTE Uncomment below to save the picture @@ -103,7 +119,7 @@ def img2img_test(): # response_img.save(r"rest_api_tests/response_img.png") -def inpainting_test(): +def inpainting_test(verbose=False): prompt = "Paint a rabbit riding on the dog" negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft" seed = 2121991605 @@ -150,10 +166,17 @@ def inpainting_test(): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print(f"[Inpainting] response from server was : {res.status_code}") + print( + f"[inpaint] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print( + f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" + ) -def outpainting_test(): +def outpainting_test(verbose=False): prompt = "Paint a rabbit riding on the dog" negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft" seed = 2121991605 @@ -200,10 +223,17 @@ def outpainting_test(): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print(f"[Outpaint] response from server was : {res.status_code}") + print( + f"[outpaint] response from server was : {res.status_code} {res.reason}" + 
) + + if verbose or res.status_code != 200: + print( + f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" + ) -def txt2img_test(): +def txt2img_test(verbose=False): prompt = "Paint a rabbit in a top hate" negative_prompt = "ugly, bad art, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, blurry, bad anatomy, blurred, watermark, grainy, tiling, signature, cut off, draft" seed = 2121991605 @@ -232,12 +262,119 @@ def txt2img_test(): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print(f"[txt2img] response from server was : {res.status_code}") + print( + f"[txt2img] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print( + f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" + ) + + +def sd_models_test(verbose=False): + url = "http://127.0.0.1:8080/sdapi/v1/sd-models" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + res = requests.get(url=url, headers=headers, timeout=1000) + + print( + f"[sd_models] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print(f"\n{res.json() if res.status_code == 200 else res.content}\n") + + +def sd_samplers_test(verbose=False): + url = "http://127.0.0.1:8080/sdapi/v1/samplers" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + res = requests.get(url=url, headers=headers, timeout=1000) + + print( + f"[sd_samplers] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print(f"\n{res.json() if res.status_code == 200 else res.content}\n") + + +def options_test(verbose=False): + url = "http://127.0.0.1:8080/sdapi/v1/options" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + res = requests.get(url=url, headers=headers, timeout=1000) + + print( + f"[options] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print(f"\n{res.json() if res.status_code == 200 else res.content}\n") + + +def cmd_flags_test(verbose=False): + url = "http://127.0.0.1:8080/sdapi/v1/cmd-flags" + + headers = { + "User-Agent": "PythonTest", + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate, br", + } + + res = requests.get(url=url, headers=headers, timeout=1000) + + print( + f"[cmd-flags] response from server was : {res.status_code} {res.reason}" + ) + + if verbose or res.status_code != 200: + print(f"\n{res.json() if res.status_code == 200 else res.content}\n") if __name__ == "__main__": - txt2img_test() - img2img_test() - upscaler_test() - inpainting_test() - outpainting_test() + import argparse + + parser = argparse.ArgumentParser( + description=( + "Exercises the Stable Diffusion REST API of Shark. Make sure " + "Shark is running in API mode on 127.0.0.1:8080 before running" + "this script." 
+ ), + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help=( + "also display selected info from the JSON response for " + "successful requests" + ), + ) + args = parser.parse_args() + + sd_models_test(args.verbose) + sd_samplers_test(args.verbose) + options_test(args.verbose) + cmd_flags_test(args.verbose) + txt2img_test(args.verbose) + img2img_test(args.verbose) + upscaler_test(args.verbose) + inpainting_test(args.verbose) + outpainting_test(args.verbose)