From f692a012e1f081e7f39d21da36a44a17c89562d7 Mon Sep 17 00:00:00 2001 From: Stefan Kapusniak <121311569+one-lithe-rune@users.noreply.github.com> Date: Thu, 14 Dec 2023 20:56:37 +0000 Subject: [PATCH 1/2] UI: Fixes for Gradio 4.7.1/4.8.0 update (#2024) * Upgrade Gradio pin from 4.7.1 to 4.80. * Make Nod AI logos visible again. * Remove image toolbars from png import boxes. * Set Input Images on img2img, outpaint and upscaler tabs to be upload only. * Change Image control to an ImageEditor control for masking on the inpaint tab. Remove previous height restriction as this hides the editing controls. * Move Input Image/Masked Image on img2img, inpaint, outpaint and upscaler tabs to be the first control on their tabs. * Remove download buttons from all galleries as they download some html rather the image (gradio issue #6595) * Remove add new row and column from Output Gallery parameters dataframe. * Add partial workaround for not being able to select text in the Output Gallery Gallery parameters dataframe (gradio issue #6086 ) * Fix uglified formatting of subdirectory selection dropown, refresh button, and open folder buttons on the Output Gallery tab. * Force Output Gallery to use the full width of the Gallery control for the preview overlay when an image is selected, rather than an overlay the width of the selected image. * Fix sendto buttons. * Reset Inpaint ImageEditor control with the Mask Layer after generation is complete, as it gets lost if the image was sent to the tab from another tab rather than being uploaded. Also rework queuing and progress rendering along this codepath. This doesn't solve the underlying problem of the Mask Layer being removed, but does get inpaint fully working with the Gradio update. --- apps/stable_diffusion/web/api/sdapi_v1.py | 3 +- apps/stable_diffusion/web/index.py | 104 +++++++++------ .../web/ui/common_ui_events.py | 2 + .../web/ui/css/sd_dark_theme.css | 23 +++- apps/stable_diffusion/web/ui/img2img_ui.py | 19 +-- apps/stable_diffusion/web/ui/inpaint_ui.py | 124 +++++++++++++++--- apps/stable_diffusion/web/ui/lora_train_ui.py | 2 +- apps/stable_diffusion/web/ui/model_manager.py | 1 + apps/stable_diffusion/web/ui/outpaint_ui.py | 14 +- .../web/ui/outputgallery_ui.py | 30 +++-- .../web/ui/txt2img_sdxl_ui.py | 21 +-- apps/stable_diffusion/web/ui/txt2img_ui.py | 24 ++-- apps/stable_diffusion/web/ui/upscaler_ui.py | 16 +-- apps/stable_diffusion/web/ui/utils.py | 22 ++++ dataset/annotation_tool.py | 1 + requirements.txt | 2 +- 16 files changed, 290 insertions(+), 118 deletions(-) diff --git a/apps/stable_diffusion/web/api/sdapi_v1.py b/apps/stable_diffusion/web/api/sdapi_v1.py index 3eebd5c113..f376f0fe9d 100644 --- a/apps/stable_diffusion/web/api/sdapi_v1.py +++ b/apps/stable_diffusion/web/api/sdapi_v1.py @@ -374,7 +374,8 @@ def inpaint_api( res = inpaint_inf( InputData.prompt, InputData.negative_prompt, - {"image": init_image, "mask": mask}, + init_image, + mask, InputData.height, InputData.width, InputData.is_full_res, diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py index f301d6d246..b679586301 100644 --- a/apps/stable_diffusion/web/index.py +++ b/apps/stable_diffusion/web/index.py @@ -77,6 +77,7 @@ # It has to be in this order or gradio ignores what we've set up. from apps.stable_diffusion.web.utils.tmp_configs import ( config_tmp, + shark_tmp, ) config_tmp() @@ -86,6 +87,8 @@ from apps.stable_diffusion.web.ui.utils import ( create_custom_models_folders, nodicon_loc, + mask_editor_value_for_gallery_data, + mask_editor_value_for_image_file, ) create_custom_models_folders() @@ -177,10 +180,20 @@ def resource_path(relative_path): # init global sd pipeline and config global_obj._init() - def register_button_click(button, selectedid, inputs, outputs): + def register_sendto_click(button, selectedid, inputs, outputs): button.click( lambda x: ( - x[0]["name"] if len(x) != 0 else None, + x.root[0].image.path if len(x.root) != 0 else None, + gr.Tabs(selected=selectedid), + ), + inputs, + outputs, + ) + + def register_sendto_editor_click(button, selectedid, inputs, outputs): + button.click( + lambda x: ( + mask_editor_value_for_gallery_data(x), gr.Tabs(selected=selectedid), ), inputs, @@ -196,9 +209,12 @@ def register_modelmanager_button(button, selectedid, inputs, outputs): ), inputs, outputs, + queue=False, ) - def register_outputgallery_button(button, selectedid, inputs, outputs): + def register_outputgallery_sendto_button( + button, selectedid, inputs, outputs + ): button.click( lambda x: ( x, @@ -208,6 +224,18 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): outputs, ) + def register_outputgallery_sendto_editor_button( + button, selectedid, inputs, outputs + ): + button.click( + lambda x: ( + mask_editor_value_for_image_file(x), + gr.Tabs(selected=selectedid), + ), + inputs, + outputs, + ) + dark_theme = resource_path("ui/css/sd_dark_theme.css") with gr.Blocks( @@ -236,19 +264,6 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): if args.output_gallery: with gr.TabItem(label="Output Gallery", id=5) as og_tab: outputgallery_web.render() - - # extra output gallery configuration - outputgallery_tab_select(og_tab.select) - outputgallery_watch( - [ - txt2img_status, - img2img_status, - inpaint_status, - outpaint_status, - upscaler_status, - txt2img_sdxl_status, - ] - ) # with gr.TabItem(label="Model Manager", id=6): # model_web.render() # with gr.TabItem(label="LoRA Training (Experimental)", id=7): @@ -268,6 +283,19 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): with gr.TabItem(label="Text-to-Image (SDXL)", id=13): txt2img_sdxl_web.render() + # extra output gallery configuration + outputgallery_tab_select(og_tab.select) + outputgallery_watch( + [ + txt2img_status, + img2img_status, + inpaint_status, + outpaint_status, + upscaler_status, + txt2img_sdxl_status, + ], + ) + actual_port = app.usable_port() if actual_port != args.server_port: sd_web.load( @@ -278,134 +306,134 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): ) # send to buttons - register_button_click( + register_sendto_click( txt2img_sendto_img2img, 1, [txt2img_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( txt2img_sendto_inpaint, 2, [txt2img_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( txt2img_sendto_outpaint, 3, [txt2img_gallery], [outpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( txt2img_sendto_upscaler, 4, [txt2img_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( img2img_sendto_inpaint, 2, [img2img_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( img2img_sendto_outpaint, 3, [img2img_gallery], [outpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( img2img_sendto_upscaler, 4, [img2img_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_click( inpaint_sendto_img2img, 1, [inpaint_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_click( inpaint_sendto_outpaint, 3, [inpaint_gallery], [outpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( inpaint_sendto_upscaler, 4, [inpaint_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_click( outpaint_sendto_img2img, 1, [outpaint_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( outpaint_sendto_inpaint, 2, [outpaint_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( outpaint_sendto_upscaler, 4, [outpaint_gallery], [upscaler_init_image, tabs], ) - register_button_click( + register_sendto_click( upscaler_sendto_img2img, 1, [upscaler_gallery], [img2img_init_image, tabs], ) - register_button_click( + register_sendto_editor_click( upscaler_sendto_inpaint, 2, [upscaler_gallery], [inpaint_init_image, tabs], ) - register_button_click( + register_sendto_click( upscaler_sendto_outpaint, 3, [upscaler_gallery], [outpaint_init_image, tabs], ) if args.output_gallery: - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_txt2img, 0, [outputgallery_filename], [txt2img_png_info_img, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_img2img, 1, [outputgallery_filename], [img2img_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_editor_button( outputgallery_sendto_inpaint, 2, [outputgallery_filename], [inpaint_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_outpaint, 3, [outputgallery_filename], [outpaint_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_upscaler, 4, [outputgallery_filename], [upscaler_init_image, tabs], ) - register_outputgallery_button( + register_outputgallery_sendto_button( outputgallery_sendto_txt2img_sdxl, 0, [outputgallery_filename], diff --git a/apps/stable_diffusion/web/ui/common_ui_events.py b/apps/stable_diffusion/web/ui/common_ui_events.py index 230619b61d..f467f6b0ed 100644 --- a/apps/stable_diffusion/web/ui/common_ui_events.py +++ b/apps/stable_diffusion/web/ui/common_ui_events.py @@ -1,3 +1,5 @@ +import gradio as gr + from apps.stable_diffusion.web.ui.utils import ( HSLHue, hsl_color, diff --git a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css index 5686f0868c..fa8d50adf2 100644 --- a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css +++ b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css @@ -239,8 +239,9 @@ footer { padding: 0 !important; } -#output_subdir_container :first-child { - border: none; +#output_subdir_container { + background-color: var(--block-background-fill); + padding-right: 8px; } /* reduced animation load when generating */ @@ -279,10 +280,19 @@ footer { /* output gallery tab */ .output_parameters_dataframe table.table { - /* works around a gradio bug that always shows scrollbars */ +/* works around a gradio bug that always shows scrollbars */ overflow: clip auto; } +.output_parameters_dataframe .cell-wrap span { + /* inadequate workaround for gradio issue #6086 */ + user-select:text !important; + -moz-user-select:text !important; + -webkit-user-select:text !important; + -o-user-select:text !important; + -ms-user-select:text !important; +} + .output_parameters_dataframe tbody td { font-size: small; line-height: var(--line-xs); @@ -291,7 +301,7 @@ footer { .output_icon_button { max-width: 30px; align-self: end; - padding-bottom: 8px; + padding-bottom: 16px !important; } .outputgallery_sendto { @@ -308,6 +318,11 @@ footer { object-fit: contain !important; } +/* use the whole gallery area for previeews */ +#outputgallery_gallery .preview { + width: inherit; +} + /* centered logo for when there are no images */ #top_logo.logo_centered { height: 100%; diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py index f3522656e4..a6df246325 100644 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ b/apps/stable_diffusion/web/ui/img2img_ui.py @@ -326,14 +326,21 @@ def img2img_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + # TODO: make this import image prompt info if it exists + img2img_init_image = gr.Image( + label="Input Image", + type="pil", + interactive=True, + sources=["upload"], + ) with gr.Row(): # janky fix for overflowing text i2i_model_info = ( @@ -380,14 +387,6 @@ def img2img_inf( lines=2, elem_id="negative_prompt_box", ) - # TODO: make this import image prompt info if it exists - img2img_init_image = gr.Image( - label="Input Image", - type="pil", - height=300, - interactive=True, - ) - with gr.Accordion(label="Multistencil Options", open=False): choices = [ "None", @@ -958,6 +957,8 @@ def update_cn_input( elem_id="gallery", columns=2, object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{i2i_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py index 8cd56f452b..4ce4795a82 100644 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ b/apps/stable_diffusion/web/ui/inpaint_ui.py @@ -3,8 +3,15 @@ import time import sys import gradio as gr +import PIL.ImageOps from PIL import Image +from gradio.components.image_editor import ( + Brush, + Eraser, + EditorData, + EditorValue, +) from apps.stable_diffusion.web.ui.utils import ( available_devices, nodlogo_loc, @@ -37,11 +44,53 @@ init_import_mlir = args.import_mlir +def set_image_states(editor_data): + input_mask = editor_data["layers"][0] + + # inpaint_inf wants white mask on black background (?), whilst ImageEditor + # delivers black mask on transparent (0 opacity) background + inference_mask = Image.new( + mode="RGB", size=input_mask.size, color=(255, 255, 255) + ) + inference_mask.paste(input_mask, input_mask) + inference_mask = PIL.ImageOps.invert(inference_mask) + + return ( + # we set the ImageEditor data again, because it likes to clear + # the image layers (which include the mask) if the user hasn't + # used the upload button, and we sent it and image + # TODO: work out what is going wrong in that case so we don't have + # to do this + { + "background": editor_data["background"], + "layers": [input_mask], + "composite": None, + }, + editor_data["background"], + input_mask, + inference_mask, + ) + + +def reload_image_editor(editor_image, editor_mask): + # we set the ImageEditor data again, because it likes to clear + # the image layers (which include the mask) if the user hasn't + # used the upload button, and we sent it the image + # TODO: work out what is going wrong in that case so we don't have + # to do this + return { + "background": editor_image, + "layers": [editor_mask], + "composite": None, + } + + # Exposed to UI. def inpaint_inf( prompt: str, negative_prompt: str, - image_dict, + image, + mask_image, height: int, width: int, inpaint_full_res: bool, @@ -175,8 +224,6 @@ def inpaint_inf( start_time = time.time() global_obj.get_sd_obj().log = "" generated_imgs = [] - image = image_dict["image"] - mask_image = image_dict["mask"] text_output = "" try: seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) @@ -223,6 +270,9 @@ def inpaint_inf( with gr.Blocks(title="Inpainting") as inpaint_web: + editor_image = gr.State() + editor_mask = gr.State() + inference_mask = gr.State() with gr.Row(elem_id="ui_title"): nod_logo = Image.open(nodlogo_loc) with gr.Row(): @@ -231,14 +281,24 @@ def inpaint_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + inpaint_init_image = gr.Sketchpad( + label="Masked Image", + type="pil", + sources=("clipboard", "upload"), + interactive=True, + brush=Brush( + colors=["#000000"], + color_mode="fixed", + ), + ) with gr.Row(): # janky fix for overflowing text inpaint_model_info = ( @@ -288,14 +348,6 @@ def inpaint_inf( lines=2, elem_id="negative_prompt_box", ) - - inpaint_init_image = gr.Image( - label="Masked Image", - sources="upload", - type="pil", - height=350, - ) - with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): # janky fix for overflowing text @@ -448,6 +500,8 @@ def inpaint_inf( elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{inpaint_model_info}\n" @@ -484,7 +538,8 @@ def inpaint_inf( inputs=[ prompt, negative_prompt, - inpaint_init_image, + editor_image, + inference_mask, height, width, inpaint_full_res, @@ -514,18 +569,53 @@ def inpaint_inf( fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs), inputs=[batch_count, batch_size], outputs=inpaint_status, + show_progress="none", + ) + set_image_states_args = dict( + fn=set_image_states, + inputs=[inpaint_init_image], + outputs=[ + inpaint_init_image, + editor_image, + editor_mask, + inference_mask, + ], + show_progress="none", + ) + reload_image_editor_args = dict( + fn=reload_image_editor, + inputs=[editor_image, editor_mask], + outputs=[inpaint_init_image], + show_progress="none", ) - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs + # all these trigger generation + prompt_submit = ( + prompt.submit(**set_image_states_args) + .then(**status_kwargs) + .then(**kwargs) + .then(**reload_image_editor_args) ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) + neg_prompt_submit = ( + negative_prompt.submit(**set_image_states_args) + .then(**status_kwargs) + .then(**kwargs) + .then(**reload_image_editor_args) + ) + generate_click = ( + stable_diffusion.click(**set_image_states_args) + .then(**status_kwargs) + .then(**kwargs) + .then(**reload_image_editor_args) + ) + + # Attempts to cancel generation stop_batch.click( fn=cancel_sd, cancels=[prompt_submit, neg_prompt_submit, generate_click], ) + # Updates LoRA information when one is selected lora_weights.change( fn=lora_changed, inputs=[lora_weights], diff --git a/apps/stable_diffusion/web/ui/lora_train_ui.py b/apps/stable_diffusion/web/ui/lora_train_ui.py index d84728bc26..45c1c3e243 100644 --- a/apps/stable_diffusion/web/ui/lora_train_ui.py +++ b/apps/stable_diffusion/web/ui/lora_train_ui.py @@ -23,10 +23,10 @@ value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): diff --git a/apps/stable_diffusion/web/ui/model_manager.py b/apps/stable_diffusion/web/ui/model_manager.py index 11e01fe873..21c0939f5e 100644 --- a/apps/stable_diffusion/web/ui/model_manager.py +++ b/apps/stable_diffusion/web/ui/model_manager.py @@ -105,6 +105,7 @@ def get_image_from_model(model_json): label="Civitai Model Gallery", value=None, visible=False, + show_download_button=False, ) with gr.Row(visible=False) as sendto_btns: diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py index 2a4c0039e7..a515f6c90e 100644 --- a/apps/stable_diffusion/web/ui/outpaint_ui.py +++ b/apps/stable_diffusion/web/ui/outpaint_ui.py @@ -236,14 +236,17 @@ def outpaint_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + outpaint_init_image = gr.Image( + label="Input Image", type="pil", sources=["upload"] + ) with gr.Row(): outpaint_model_info = ( f"Custom Model Path: {str(get_custom_model_path())}" @@ -291,13 +294,6 @@ def outpaint_inf( lines=2, elem_id="negative_prompt_box", ) - - outpaint_init_image = gr.Image( - label="Input Image", - type="pil", - height=300, - ) - with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): # janky fix for overflowing text @@ -473,6 +469,8 @@ def outpaint_inf( elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{outpaint_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/outputgallery_ui.py b/apps/stable_diffusion/web/ui/outputgallery_ui.py index 35ef80736f..d33e5f5393 100644 --- a/apps/stable_diffusion/web/ui/outputgallery_ui.py +++ b/apps/stable_diffusion/web/ui/outputgallery_ui.py @@ -80,28 +80,28 @@ def output_subdirs() -> list[str]: label="Getting subdirectories...", value=nod_logo, interactive=False, + show_download_button=False, visible=True, show_label=True, elem_id="top_logo", elem_classes="logo_centered", - show_download_button=False, ) - gallery = gr.Gallery( label="", value=gallery_files.value, visible=False, show_label=True, columns=4, + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) with gr.Column(scale=4): with gr.Group(): - with gr.Row(): + with gr.Row(elem_id="output_subdir_container"): with gr.Column( scale=15, min_width=160, - elem_id="output_subdir_container", ): subdirectories = gr.Dropdown( label=f"Subdirectories of {output_dir}", @@ -109,7 +109,7 @@ def output_subdirs() -> list[str]: choices=subdirectory_paths.value, value="", interactive=True, - elem_classes="dropdown_no_container", + # elem_classes="dropdown_no_container", allow_custom_value=True, ) with gr.Column( @@ -149,11 +149,12 @@ def output_subdirs() -> list[str]: ) as parameters_accordian: image_parameters = gr.DataFrame( headers=["Parameter", "Value"], - col_count=2, + col_count=(2, "fixed"), + row_count=(1, "fixed"), wrap=True, elem_classes="output_parameters_dataframe", value=[["Status", "No image selected"]], - interactive=True, + interactive=False, ) with gr.Accordion(label="Send To", open=True): @@ -327,12 +328,18 @@ def on_select_image(images: list[str], evt: gr.SelectData) -> list: else: return [ filename, - list(map(list, params["parameters"].items())), + gr.DataFrame( + value=list(map(list, params["parameters"].items())), + row_count=(len(params["parameters"]), "fixed"), + ), ] return [ filename, - [["Status", "No parameters found"]], + gr.DataFrame( + value=[["Status", "No parameters found"]], + row_count=(1, "fixed"), + ), ] def on_outputgallery_filename_change(filename: str) -> list: @@ -450,11 +457,12 @@ def outputgallery_tab_select(select): # We should have been passed a list of components on other tabs that update # when a new image has generated on that tab, so set things up so the user # will see that new image if they are looking at today's subdirectory - def outputgallery_watch(components: gr.Textbox): + def outputgallery_watch(components: gr.Textbox, queued_components=[]): for component in components: component.change( on_new_image, inputs=[subdirectories, subdirectory_paths, component], outputs=[gallery_files, gallery, logo], - queue=False, + queue=component in queued_components, + show_progress="none", ) diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py index 85dae66d2e..807c30ad2e 100644 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -240,10 +240,10 @@ def txt2img_sdxl_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): @@ -264,7 +264,7 @@ def txt2img_sdxl_inf( custom_checkpoint_type="sdxl" ), allow_custom_value=True, - scale=2, + scale=11, ) t2i_sdxl_vae_info = ( str(get_custom_model_path("vae")) @@ -283,15 +283,16 @@ def txt2img_sdxl_inf( ] + get_custom_model_files("vae"), allow_custom_value=True, + scale=4, + ) + txt2img_sdxl_png_info_img = gr.Image( scale=1, + label="Import PNG info", + elem_id="txt2img_prompt_image", + type="pil", + visible=True, + sources=["upload"], ) - with gr.Column(scale=1, min_width=170): - txt2img_sdxl_png_info_img = gr.Image( - label="Import PNG info", - elem_id="txt2img_prompt_image", - type="pil", - visible=True, - ) with gr.Group(elem_id="prompt_box_outer"): txt2img_sdxl_autogen = gr.Checkbox( @@ -477,6 +478,8 @@ def txt2img_sdxl_inf( elem_id="gallery", columns=[2], object_fit="scale_down", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{t2i_sdxl_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py index 9df392a90a..3b6c936cf8 100644 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_ui.py @@ -427,16 +427,16 @@ def onload_load_settings(): value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): with gr.Row(): - with gr.Column(scale=10): + with gr.Column(): with gr.Row(): t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}" txt2img_custom_model = gr.Dropdown( @@ -449,7 +449,7 @@ def onload_load_settings(): choices=get_custom_model_files() + predefined_models, allow_custom_value=True, - scale=2, + scale=11, ) # janky fix for overflowing text t2i_vae_info = ( @@ -464,16 +464,16 @@ def onload_load_settings(): choices=["None"] + get_custom_model_files("vae"), allow_custom_value=True, + scale=4, + ) + txt2img_png_info_img = gr.Image( + label="Import PNG info", + elem_id="txt2img_prompt_image", + type="pil", + visible=True, + sources=["upload"], scale=1, ) - with gr.Column(scale=1, min_width=170): - txt2img_png_info_img = gr.Image( - label="Import PNG info", - elem_id="txt2img_prompt_image", - type="pil", - visible=True, - ) - with gr.Group(elem_id="prompt_box_outer"): prompt = gr.Textbox( label="Prompt", @@ -688,6 +688,8 @@ def onload_load_settings(): elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{t2i_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py index 42157dbd98..88d0507adb 100644 --- a/apps/stable_diffusion/web/ui/upscaler_ui.py +++ b/apps/stable_diffusion/web/ui/upscaler_ui.py @@ -255,14 +255,19 @@ def upscaler_inf( value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=50, - show_download_button=False, ) with gr.Row(elem_id="ui_body"): with gr.Row(): with gr.Column(scale=1, min_width=600): + upscaler_init_image = gr.Image( + label="Input Image", + type="pil", + sources=["upload"], + ) with gr.Row(): upscaler_model_info = ( f"Custom Model Path: {str(get_custom_model_path())}" @@ -311,13 +316,6 @@ def upscaler_inf( lines=2, elem_id="negative_prompt_box", ) - - upscaler_init_image = gr.Image( - label="Input Image", - type="pil", - height=300, - ) - with gr.Accordion(label="LoRA Options", open=False): with gr.Row(): # janky fix for overflowing text @@ -471,6 +469,8 @@ def upscaler_inf( elem_id="gallery", columns=[2], object_fit="contain", + # TODO: Re-enable download when fixed in Gradio + show_download_button=False, ) std_output = gr.Textbox( value=f"{upscaler_model_info}\n" diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py index 9252ecee9f..0572089e84 100644 --- a/apps/stable_diffusion/web/ui/utils.py +++ b/apps/stable_diffusion/web/ui/utils.py @@ -5,11 +5,13 @@ import json import safetensors import gradio as gr +import PIL.Image as Image from pathlib import Path from apps.stable_diffusion.src import args from dataclasses import dataclass from enum import IntEnum +from gradio.components.image_editor import EditorValue from apps.stable_diffusion.src import get_available_devices import apps.stable_diffusion.web.utils.global_obj as global_obj @@ -315,6 +317,25 @@ def default_config_exists(model_ckpt_or_id): return None +def mask_editor_value_for_image_file(filepath): + image = Image.open(filepath) + mask = Image.new(mode="RGBA", size=image.size, color=(0, 0, 0, 0)) + return {"background": image, "layers": [mask], "composite": image} + + +def mask_editor_value_for_gallery_data(gallery_data): + filepath = ( + gallery_data.root[0].image.path + if len(gallery_data.root) != 0 + else None + ) + + if os.path.isfile(filepath): + return mask_editor_value_for_image_file(filepath) + + return EditorValue() + + default_configs = { "stabilityai/sdxl-turbo": [ gr.Textbox(label="", interactive=False, value=None, visible=False), @@ -350,6 +371,7 @@ def default_config_exists(model_ckpt_or_id): ], } + nodlogo_loc = resource_path("logos/nod-logo.png") nodicon_loc = resource_path("logos/nod-icon.png") available_devices = get_available_devices() diff --git a/dataset/annotation_tool.py b/dataset/annotation_tool.py index edd088229f..8c8c85cdfd 100644 --- a/dataset/annotation_tool.py +++ b/dataset/annotation_tool.py @@ -23,6 +23,7 @@ value=nod_logo, show_label=False, interactive=False, + show_download_button=False, elem_id="top_logo", width=150, height=100, diff --git a/requirements.txt b/requirements.txt index a97baa83a3..ff649a4468 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,7 +26,7 @@ diffusers accelerate scipy ftfy -gradio==4.7.1 +gradio==4.8.0 altair omegaconf # 0.3.2 doesn't have binaries for arm64 From ebfcfec3389479ddc70999b04ac415ca3066b829 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Thu, 14 Dec 2023 21:44:37 -0600 Subject: [PATCH 2/2] remove shark 1.0 tests, add support for 2.0 llm * add support for external weights * add tests and edit deps --- .github/workflows/test-models.yml | 164 -------------- .github/workflows/test-studio.yml | 86 ++++++++ apps/shark_studio/api/llm.py | 163 ++++++++++---- apps/shark_studio/api/utils.py | 4 +- apps/shark_studio/tests/api_test.py | 34 +++ apps/shark_studio/web/index.py | 6 +- apps/shark_studio/web/ui/chat.py | 273 +++--------------------- build_tools/stable_diffusion_testing.py | 28 +-- dataset/annotation_tool.py | 31 +-- process_skipfiles.py | 7 +- pyproject.toml | 19 +- pytest.ini | 2 +- requirements.txt | 11 +- rest_api_tests/api_test.py | 64 ++---- setup.py | 7 - shark/iree_utils/compile_utils.py | 54 ++++- 16 files changed, 377 insertions(+), 576 deletions(-) delete mode 100644 .github/workflows/test-models.yml create mode 100644 .github/workflows/test-studio.yml create mode 100644 apps/shark_studio/tests/api_test.py diff --git a/.github/workflows/test-models.yml b/.github/workflows/test-models.yml deleted file mode 100644 index 8e9809ee41..0000000000 --- a/.github/workflows/test-models.yml +++ /dev/null @@ -1,164 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Validate Models on Shark Runtime - -on: - push: - branches: [ main ] - paths-ignore: - - '**.md' - - 'shark/examples/**' - pull_request: - branches: [ main ] - paths-ignore: - - '**.md' - - 'shark/examples/**' - workflow_dispatch: - -# Ensure that only a single job or workflow using the same -# concurrency group will run at a time. This would cancel -# any in-progress jobs in the same github workflow and github -# ref (e.g. refs/heads/main or refs/pull//merge). -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - build-validate: - strategy: - fail-fast: true - matrix: - os: [7950x, icelake, a100, MacStudio, ubuntu-latest] - suite: [cpu,cuda,vulkan] - python-version: ["3.11"] - include: - - os: ubuntu-latest - suite: lint - - os: MacStudio - suite: metal - exclude: - - os: ubuntu-latest - suite: vulkan - - os: ubuntu-latest - suite: cuda - - os: ubuntu-latest - suite: cpu - - os: MacStudio - suite: cuda - - os: MacStudio - suite: cpu - - os: MacStudio - suite: vulkan - - os: icelake - suite: vulkan - - os: icelake - suite: cuda - - os: a100 - suite: cpu - - os: 7950x - suite: cpu - - os: 7950x - suite: cuda - - runs-on: ${{ matrix.os }} - - steps: - - uses: actions/checkout@v3 - - - name: Set Environment Variables - if: matrix.os != '7950x' - run: | - echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV - echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV - - - name: Set up Python Version File ${{ matrix.python-version }} - if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake' - run: | - # See https://github.com/actions/setup-python/issues/433 - echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version - - - name: Set up Python ${{ matrix.python-version }} - if: matrix.os == 'a100' || matrix.os == 'ubuntu-latest' || matrix.os == 'icelake' - uses: actions/setup-python@v4 - with: - python-version: '${{ matrix.python-version }}' - #cache: 'pip' - #cache-dependency-path: | - # **/requirements-importer.txt - # **/requirements.txt - - - name: Install dependencies - if: matrix.suite == 'lint' - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest toml black - - - name: Lint with flake8 - if: matrix.suite == 'lint' - run: | - # black format check - black --version - black --check . - # stop the build if there are Python syntax errors or undefined names - flake8 . --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \ - --statistics --exclude lit.cfg.py - - - name: Validate Models on CPU - if: matrix.suite == 'cpu' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --benchmark=native --update_tank -k cpu - gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv - gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv - python build_tools/vicuna_testing.py - - - name: Validate Models on NVIDIA GPU - if: matrix.suite == 'cuda' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --benchmark=native --update_tank -k cuda - gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv - gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv - # Disabled due to black image bug - # python build_tools/stable_diffusion_testing.py --device=cuda - - - name: Validate Vulkan Models (MacOS) - if: matrix.suite == 'metal' && matrix.os == 'MacStudio' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} ./setup_venv.sh - source shark.venv/bin/activate - echo $PATH - pip list | grep -E "torch|iree" - # disabled due to a low-visibility memory issue with pytest on macos. - # pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" --tank_url="gs://shark_tank/nightly/" -k metal - - - name: Validate Vulkan Models (a100) - if: matrix.suite == 'vulkan' && matrix.os == 'a100' - run: | - cd $GITHUB_WORKSPACE - PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh - source shark.venv/bin/activate - pytest --update_tank -k vulkan - python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail - - - name: Validate Vulkan Models (Windows) - if: matrix.suite == 'vulkan' && matrix.os == '7950x' - run: | - ./setup_venv.ps1 - pytest -k vulkan -s --ci - - - name: Validate Stable Diffusion Models (Windows) - if: matrix.suite == 'vulkan' && matrix.os == '7950x' - run: | - ./setup_venv.ps1 - python process_skipfiles.py - pyinstaller .\apps\stable_diffusion\shark_sd.spec - python build_tools/stable_diffusion_testing.py --device=vulkan diff --git a/.github/workflows/test-studio.yml b/.github/workflows/test-studio.yml new file mode 100644 index 0000000000..765a6bf761 --- /dev/null +++ b/.github/workflows/test-studio.yml @@ -0,0 +1,86 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Validate Shark Studio + +on: + push: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + pull_request: + branches: [ main ] + paths-ignore: + - '**.md' + - 'shark/examples/**' + workflow_dispatch: + +# Ensure that only a single job or workflow using the same +# concurrency group will run at a time. This would cancel +# any in-progress jobs in the same github workflow and github +# ref (e.g. refs/heads/main or refs/pull//merge). +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-validate: + strategy: + fail-fast: true + matrix: + os: [nodai-ubuntu-builder-large] + suite: [cpu] #,cuda,vulkan] + python-version: ["3.11"] + include: + - os: nodai-ubuntu-builder-large + suite: lint + + runs-on: ${{ matrix.os }} + + steps: + - uses: actions/checkout@v3 + + - name: Set Environment Variables + run: | + echo "SHORT_SHA=`git rev-parse --short=4 HEAD`" >> $GITHUB_ENV + echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV + + - name: Set up Python Version File ${{ matrix.python-version }} + run: | + echo ${{ matrix.python-version }} >> $GITHUB_WORKSPACE/.python-version + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: '${{ matrix.python-version }}' + + - name: Install dependencies + if: matrix.suite == 'lint' + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest toml black + + - name: Lint with flake8 + if: matrix.suite == 'lint' + run: | + # black format check + black --version + black --check apps/shark_studio + # stop the build if there are Python syntax errors or undefined names + flake8 . --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 . --isolated --count --exit-zero --max-complexity=10 --max-line-length=127 \ + --statistics --exclude lit.cfg.py + + - name: Validate Models on CPU + if: matrix.suite == 'cpu' + run: | + cd $GITHUB_WORKSPACE + python${{ matrix.python-version }} -m venv shark.venv + source shark.venv/bin/activate + pip install -r requirements.txt --no-cache-dir + pip install -e . + pip uninstall -y torch + pip install torch==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python apps/shark_studio/tests/api_test.py diff --git a/apps/shark_studio/api/llm.py b/apps/shark_studio/api/llm.py index 9e92e58cb5..1a03b817ff 100644 --- a/apps/shark_studio/api/llm.py +++ b/apps/shark_studio/api/llm.py @@ -1,9 +1,16 @@ from turbine_models.custom_models import stateless_llama -from shark.iree_utils.compile_utils import get_iree_compiled_module +import time +from shark.iree_utils.compile_utils import ( + get_iree_compiled_module, + load_vmfb_using_mmap, +) from apps.shark_studio.api.utils import get_resource_path import iree.runtime as ireert +from itertools import chain import gc +import os import torch +from transformers import AutoTokenizer llm_model_map = { "llama2_7b": { @@ -11,81 +18,161 @@ "hf_model_name": "meta-llama/Llama-2-7b-chat-hf", "stop_token": 2, "max_tokens": 4096, - } + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, + "Trelis/Llama-2-7b-chat-hf-function-calling-v2": { + "initializer": stateless_llama.export_transformer_model, + "hf_model_name": "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + "stop_token": 2, + "max_tokens": 4096, + "system_prompt": """[INST] <>Be concise. You are a helpful, respectful and honest assistant. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. <>""", + }, } class LanguageModel: def __init__( - self, model_name, hf_auth_token=None, device=None, precision="fp32" + self, + model_name, + hf_auth_token=None, + device=None, + precision="fp32", + external_weights=None, + use_system_prompt=True, ): print(llm_model_map[model_name]) self.hf_model_name = llm_model_map[model_name]["hf_model_name"] - self.torch_ir, self.tokenizer = llm_model_map[model_name][ - "initializer" - ](self.hf_model_name, hf_auth_token, compile_to="torch") self.tempfile_name = get_resource_path("llm.torch.tempfile") - with open(self.tempfile_name, "w+") as f: - f.write(self.torch_ir) - del self.torch_ir - gc.collect() - + self.vmfb_name = get_resource_path("llm.vmfb.tempfile") self.device = device self.precision = precision + self.safe_name = self.hf_model_name.strip("/").replace("/", "_") self.max_tokens = llm_model_map[model_name]["max_tokens"] self.iree_module_dict = None - self.compile() + self.external_weight_file = None + if external_weights is not None: + self.external_weight_file = get_resource_path( + self.safe_name + "." + external_weights + ) + self.use_system_prompt = use_system_prompt + self.global_iter = 0 + if os.path.exists(self.vmfb_name) and ( + external_weights is None or os.path.exists(str(self.external_weight_file)) + ): + self.iree_module_dict = dict() + ( + self.iree_module_dict["vmfb"], + self.iree_module_dict["config"], + self.iree_module_dict["temp_file_to_unlink"], + ) = load_vmfb_using_mmap( + self.vmfb_name, + device, + device_idx=0, + rt_flags=[], + external_weight_file=self.external_weight_file, + ) + self.tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + elif not os.path.exists(self.tempfile_name): + self.torch_ir, self.tokenizer = llm_model_map[model_name]["initializer"]( + self.hf_model_name, + hf_auth_token, + compile_to="torch", + external_weights=external_weights, + external_weight_file=self.external_weight_file, + ) + with open(self.tempfile_name, "w+") as f: + f.write(self.torch_ir) + del self.torch_ir + gc.collect() + self.compile() + else: + self.tokenizer = AutoTokenizer.from_pretrained( + self.hf_model_name, + use_fast=False, + use_auth_token=hf_auth_token, + ) + self.compile() def compile(self) -> None: # this comes with keys: "vmfb", "config", and "temp_file_to_unlink". self.iree_module_dict = get_iree_compiled_module( - self.tempfile_name, device=self.device, frontend="torch" + self.tempfile_name, + device=self.device, + mmap=True, + frontend="torch", + external_weight_file=self.external_weight_file, + write_to=self.vmfb_name, ) # TODO: delete the temp file + def sanitize_prompt(self, prompt): + print(prompt) + if isinstance(prompt, list): + prompt = list(chain.from_iterable(prompt)) + prompt = " ".join([x for x in prompt if isinstance(x, str)]) + prompt = prompt.replace("\n", " ") + prompt = prompt.replace("\t", " ") + prompt = prompt.replace("\r", " ") + if self.use_system_prompt and self.global_iter == 0: + prompt = llm_model_map["llama2_7b"]["system_prompt"] + prompt + prompt += " [/INST]" + print(prompt) + return prompt + def chat(self, prompt): + prompt = self.sanitize_prompt(prompt) + + input_tensor = self.tokenizer(prompt, return_tensors="pt").input_ids + + def format_out(results): + return torch.tensor(results.to_host()[0][0]) + history = [] for iter in range(self.max_tokens): - input_tensor = self.tokenizer( - prompt, return_tensors="pt" - ).input_ids - device_inputs = [ - ireert.asdevicearray( - self.iree_module_dict["config"], input_tensor - ) - ] + st_time = time.time() if iter == 0: - token = torch.tensor( - self.iree_module_dict["vmfb"]["run_initialize"]( - *device_inputs - ).to_host()[0][0] - ) + device_inputs = [ + ireert.asdevicearray( + self.iree_module_dict["config"].device, input_tensor + ) + ] + token = self.iree_module_dict["vmfb"]["run_initialize"](*device_inputs) else: - token = torch.tensor( - self.iree_module_dict["vmfb"]["run_forward"]( - *device_inputs - ).to_host()[0][0] - ) + device_inputs = [ + ireert.asdevicearray( + self.iree_module_dict["config"].device, + token, + ) + ] + token = self.iree_module_dict["vmfb"]["run_forward"](*device_inputs) - history.append(token) - yield self.tokenizer.decode(history) + total_time = time.time() - st_time + history.append(format_out(token)) + yield self.tokenizer.decode(history), total_time - if token == llm_model_map["llama2_7b"]["stop_token"]: + if format_out(token) == llm_model_map["llama2_7b"]["stop_token"]: break for i in range(len(history)): if type(history[i]) != int: history[i] = int(history[i]) result_output = self.tokenizer.decode(history) - yield result_output + self.global_iter += 1 + return result_output, total_time if __name__ == "__main__": lm = LanguageModel( - "llama2_7b", - hf_auth_token="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", + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, device="cpu-task", + external_weights="safetensors", ) + print("model loaded") - for i in lm.chat("Hello, I am a robot."): + for i in lm.chat("hi, what are you?"): print(i) diff --git a/apps/shark_studio/api/utils.py b/apps/shark_studio/api/utils.py index bb5e150364..4072491cbf 100644 --- a/apps/shark_studio/api/utils.py +++ b/apps/shark_studio/api/utils.py @@ -8,7 +8,5 @@ def get_available_devices(): def get_resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) diff --git a/apps/shark_studio/tests/api_test.py b/apps/shark_studio/tests/api_test.py new file mode 100644 index 0000000000..c88a1e70cb --- /dev/null +++ b/apps/shark_studio/tests/api_test.py @@ -0,0 +1,34 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +import unittest +from apps.shark_studio.api.llm import LanguageModel + + +class LLMAPITest(unittest.TestCase): + def testLLMSimple(self): + lm = LanguageModel( + "Trelis/Llama-2-7b-chat-hf-function-calling-v2", + hf_auth_token=None, + device="cpu-task", + external_weights="safetensors", + ) + count = 0 + for msg, _ in lm.chat("hi, what are you?"): + # skip first token output + if count == 0: + count += 1 + continue + assert ( + msg.strip(" ") == "Hello" + ), f"LLM API failed to return correct response, expected 'Hello', received {msg}" + break + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/apps/shark_studio/web/index.py b/apps/shark_studio/web/index.py index 59b66bee23..3ef6bc5739 100644 --- a/apps/shark_studio/web/index.py +++ b/apps/shark_studio/web/index.py @@ -93,9 +93,7 @@ def launch_app(address): def resource_path(relative_path): """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) + base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_path, relative_path) dark_theme = resource_path("ui/css/sd_dark_theme.css") @@ -201,7 +199,7 @@ def register_outputgallery_button(button, selectedid, inputs, outputs): ) with gr.Blocks( - css=dark_theme, analytics_enabled=False, title="Stable Diffusion" + css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta" ) as sd_web: with gr.Tabs() as tabs: # NOTE: If adding, removing, or re-ordering tabs, make sure that they diff --git a/apps/shark_studio/web/ui/chat.py b/apps/shark_studio/web/ui/chat.py index dd1c2d94e3..4726eef6e8 100644 --- a/apps/shark_studio/web/ui/chat.py +++ b/apps/shark_studio/web/ui/chat.py @@ -1,4 +1,5 @@ import gradio as gr +import time import os from pathlib import Path from datetime import datetime as dt @@ -21,104 +22,12 @@ def user(message, history): language_model = None -# NOTE: Each `model_name` should have its own start message -start_message = { - "llama2_7b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_13b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_70b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "vicuna": ( - "A chat between a curious user and an artificial intelligence " - "assistant. The assistant gives helpful, detailed, and " - "polite answers to the user's questions.\n" - ), -} - - def create_prompt(model_name, history, prompt_prefix): return "" - system_message = "" - if prompt_prefix: - system_message = start_message[model_name] - - if "llama2" in model_name: - B_INST, E_INST = "[INST]", "[/INST]" - B_SYS, E_SYS = "<>\n", "\n<>\n\n" - conversation = "".join( - [f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]] - ) - if prompt_prefix: - msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}" - else: - msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}" - elif model_name in ["vicuna"]: - conversation = "".join( - [ - "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) - for item in history - ] - ) - msg = system_message + conversation - msg = msg.strip() - else: - conversation = "".join( - ["".join([item[0], item[1]]) for item in history] - ) - msg = system_message + conversation - msg = msg.strip() - return msg def get_default_config(): return False - import torch - from transformers import AutoTokenizer - - hf_model_path = "TheBloke/vicuna-7B-1.1-HF" - tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False) - compilation_prompt = "".join(["0" for _ in range(17)]) - compilation_input_ids = tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor(compilation_input_ids).reshape( - [1, 19] - ) - firstVicunaCompileInput = (compilation_input_ids,) - from apps.language_models.src.model_wrappers.vicuna_model import ( - CombinedModel, - ) - from shark.shark_generate_model_config import GenerateConfigFile - - model = CombinedModel() - c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput) - c.split_into_layers() # model_vmfb_key = "" @@ -133,153 +42,37 @@ def chat_fn( download_vmfb, config_file, cli=False, - progress=gr.Progress(), ): global language_model if language_model is None: + history[-1][-1] = "Getting the model ready..." + yield history, "" language_model = LanguageModel( - model, device=device, precision=precision - ) - - language_model.chat(prompt_prefix) - return "", "" - global past_key_values - global model_vmfb_key - - device_id = None - model_name, model_path = list(map(str.strip, model.split("=>"))) - if "cuda" in device: - device = "cuda" - elif "sync" in device: - device = "cpu-sync" - elif "task" in device: - device = "cpu-task" - elif "vulkan" in device: - device_id = int(device.split("://")[1]) - device = "vulkan" - elif "rocm" in device: - device = "rocm" - else: - print("unrecognized device") - - from apps.language_models.scripts.vicuna import ShardedVicuna - from apps.language_models.scripts.vicuna import UnshardedVicuna - from apps.stable_diffusion.src import args - - new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}" - if vicuna_model is None or new_model_vmfb_key != model_vmfb_key: - model_vmfb_key = new_model_vmfb_key - max_toks = 128 if model_name == "codegen" else 512 - - # get iree flags that need to be overridden, from commandline args - _extra_args = [] - # vulkan target triple - vulkan_target_triple = args.iree_vulkan_target_triple - from shark.iree_utils.vulkan_utils import ( - get_all_vulkan_devices, - get_vulkan_target_triple, + model, + device=device, + precision=precision, + external_weights="safetensors", + external_weight_file="llama2_7b.safetensors", + use_system_prompt=prompt_prefix, ) - - if device == "vulkan": - vulkaninfo_list = get_all_vulkan_devices() - if vulkan_target_triple == "": - # We already have the device_id extracted via WebUI, so we directly use - # that to find the target triple. - vulkan_target_triple = get_vulkan_target_triple( - vulkaninfo_list[device_id] - ) - _extra_args.append( - f"-iree-vulkan-target-triple={vulkan_target_triple}" - ) - if "rdna" in vulkan_target_triple: - flags_to_add = [ - "--iree-spirv-index-bits=64", - ] - _extra_args = _extra_args + flags_to_add - - if device_id is None: - id = 0 - for device in vulkaninfo_list: - target_triple = get_vulkan_target_triple( - vulkaninfo_list[id] - ) - if target_triple == vulkan_target_triple: - device_id = id - break - id += 1 - - assert ( - device_id - ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" - print(f"Will use vulkan target triple : {vulkan_target_triple}") - - elif "rocm" in device: - # add iree rocm flags - _extra_args.append( - f"--iree-rocm-target-chip={args.iree_rocm_target_chip}" - ) - print(f"extra args = {_extra_args}") - - if model_name == "vicuna4": - vicuna_model = ShardedVicuna( - model_name, - hf_model_path=model_path, - device=device, - precision=precision, - max_num_tokens=max_toks, - compressed=True, - extra_args_cmd=_extra_args, - ) - else: - # if config_file is None: - vicuna_model = UnshardedVicuna( - model_name, - hf_model_path=model_path, - hf_auth_token=args.hf_auth_token, - device=device, - vulkan_target_triple=vulkan_target_triple, - precision=precision, - max_num_tokens=max_toks, - download_vmfb=download_vmfb, - load_mlir_from_shark_tank=True, - extra_args_cmd=_extra_args, - device_id=device_id, - ) - - if vicuna_model is None: - sys.exit("Unable to instantiate the model object, exiting.") - - prompt = create_prompt(model_name, history, prompt_prefix) - - partial_text = "" + history[-1][-1] = "Getting the model ready... Done" + yield history, "" + history[-1][-1] = "" token_count = 0 - total_time_ms = 0.001 # In order to avoid divide by zero error + total_time = 0.001 # In order to avoid divide by zero error prefill_time = 0 is_first = True - for text, msg, exec_time in progress.tqdm( - vicuna_model.generate(prompt, cli=cli), - desc="generating response", - ): - if msg is None: - if is_first: - prefill_time = exec_time - is_first = False - else: - total_time_ms += exec_time - token_count += 1 - partial_text += text + " " - history[-1][1] = partial_text + for text, exec_time in language_model.chat(history): + history[-1][-1] = text + if is_first: + prefill_time = exec_time + is_first = False yield history, f"Prefill: {prefill_time:.2f}" - elif "formatted" in msg: - history[-1][1] = text - tokens_per_sec = (token_count / total_time_ms) * 1000 - yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" else: - sys.exit( - "unexpected message from the vicuna generate call, exiting." - ) - - return history, "" + total_time += exec_time + token_count += 1 + tokens_per_sec = token_count / total_time + yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" def llm_chat_api(InputData: dict): @@ -297,17 +90,11 @@ def llm_chat_api(InputData: dict): # print(f"prompt : {InputData['prompt']}") # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now global vicuna_model - model_name = ( - InputData["model"] if "model" in InputData.keys() else "codegen" - ) + model_name = InputData["model"] if "model" in InputData.keys() else "codegen" model_path = llm_model_map[model_name] device = "cpu-task" precision = "fp16" - max_toks = ( - None - if "max_tokens" not in InputData.keys() - else InputData["max_tokens"] - ) + max_toks = None if "max_tokens" not in InputData.keys() else InputData["max_tokens"] if max_toks is None: max_toks = 128 if model_name == "codegen" else 512 @@ -344,9 +131,7 @@ def llm_chat_api(InputData: dict): # TODO: add role dict for different models if is_chat_completion_api: # TODO: add funtionality for multiple messages - prompt = create_prompt( - model_name, [(InputData["messages"][0]["content"], "")] - ) + prompt = create_prompt(model_name, [(InputData["messages"][0]["content"], "")]) else: prompt = InputData["prompt"] print("prompt = ", prompt) @@ -379,9 +164,7 @@ def llm_chat_api(InputData: dict): end_time = dt.now().strftime("%Y%m%d%H%M%S%f") return { "id": end_time, - "object": "chat.completion" - if is_chat_completion_api - else "text_completion", + "object": "chat.completion" if is_chat_completion_api else "text_completion", "created": int(end_time), "choices": choices, } @@ -457,9 +240,7 @@ def view_json_file(file_obj): with gr.Row(visible=False): with gr.Group(): - config_file = gr.File( - label="Upload sharding configuration", visible=False - ) + config_file = gr.File(label="Upload sharding configuration", visible=False) json_view_button = gr.Button(label="View as JSON", visible=False) json_view = gr.JSON(interactive=True, visible=False) json_view_button.click( diff --git a/build_tools/stable_diffusion_testing.py b/build_tools/stable_diffusion_testing.py index ced919732c..8eeb1a7395 100644 --- a/build_tools/stable_diffusion_testing.py +++ b/build_tools/stable_diffusion_testing.py @@ -36,9 +36,7 @@ def parse_sd_out(filename, command, device, use_tune, model_name, import_mlir): metrics[val] = line.split(" ")[-1].strip("\n") metrics["Average step"] = metrics["Average step"].strip("ms/it") - metrics["Total image generation"] = metrics[ - "Total image generation" - ].strip("sec") + metrics["Total image generation"] = metrics["Total image generation"].strip("sec") metrics["device"] = device metrics["use_tune"] = use_tune metrics["model_name"] = model_name @@ -84,10 +82,14 @@ def test_loop( ] import_options = ["--import_mlir", "--no-import_mlir"] prompt_text = "--prompt=cyberpunk forest by Salvador Dali" - inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + inpaint_prompt_text = ( + "--prompt=Face of a yellow cat, high resolution, sitting on a park bench" + ) if os.name == "nt": prompt_text = '--prompt="cyberpunk forest by Salvador Dali"' - inpaint_prompt_text = '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + inpaint_prompt_text = ( + '--prompt="Face of a yellow cat, high resolution, sitting on a park bench"' + ) if beta: extra_flags.append("--beta_models=True") extra_flags.append("--no-progress_bar") @@ -174,9 +176,7 @@ def test_loop( ) print(command) print("Successfully generated image") - os.makedirs( - "./test_images/golden/" + model_name, exist_ok=True - ) + os.makedirs("./test_images/golden/" + model_name, exist_ok=True) download_public_file( "gs://shark_tank/testdata/golden/" + model_name, "./test_images/golden/" + model_name, @@ -191,14 +191,10 @@ def test_loop( ) test_file = glob(test_file_path)[0] - golden_path = ( - "./test_images/golden/" + model_name + "/*.png" - ) + golden_path = "./test_images/golden/" + model_name + "/*.png" golden_file = glob(golden_path)[0] try: - compare_images( - test_file, golden_file, upload=upload_bool - ) + compare_images(test_file, golden_file, upload=upload_bool) except AssertionError as e: print(e) if exit_on_fail == True: @@ -267,9 +263,7 @@ def prepare_artifacts(): parser.add_argument( "-x", "--exit_on_fail", action=argparse.BooleanOptionalAction, default=True ) -parser.add_argument( - "-g", "--gen", action=argparse.BooleanOptionalAction, default=False -) +parser.add_argument("-g", "--gen", action=argparse.BooleanOptionalAction, default=False) if __name__ == "__main__": args = parser.parse_args() diff --git a/dataset/annotation_tool.py b/dataset/annotation_tool.py index 8c8c85cdfd..60f607146d 100644 --- a/dataset/annotation_tool.py +++ b/dataset/annotation_tool.py @@ -10,9 +10,7 @@ shark_root = Path(__file__).parent.parent demo_css = shark_root.joinpath("web/demo.css").resolve() -nodlogo_loc = shark_root.joinpath( - "web/models/stable_diffusion/logos/nod-logo.png" -) +nodlogo_loc = shark_root.joinpath("web/models/stable_diffusion/logos/nod-logo.png") with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web: @@ -76,9 +74,7 @@ def filter_datasets(dataset): with jsonlines.open(dataset_path + "/metadata.jsonl") as reader: for line in reader.iter(type=dict, skip_invalid=True): prompt_data[line["file_name"]] = ( - [line["text"]] - if type(line["text"]) is str - else line["text"] + [line["text"]] if type(line["text"]) is str else line["text"] ) return gr.Dropdown.update(choices=images[dataset]) @@ -104,9 +100,7 @@ def display_image(dataset, image_name): prompt_data[image_name] = [] prompt_choices = ["Add new"] prompt_choices += prompt_data[image_name] - return gr.Image.update(value=img), gr.Dropdown.update( - choices=prompt_choices - ) + return gr.Image.update(value=img), gr.Dropdown.update(choices=prompt_choices) image_name.change( fn=display_image, @@ -123,12 +117,7 @@ def edit_prompt(prompts): prompts.change(fn=edit_prompt, inputs=prompts, outputs=prompt) def save_prompt(dataset, image_name, prompts, prompt): - if ( - dataset is None - or image_name is None - or prompts is None - or prompt is None - ): + if dataset is None or image_name is None or prompts is None or prompt is None: return if prompts == "Add new": @@ -137,9 +126,7 @@ def save_prompt(dataset, image_name, prompts, prompt): idx = prompt_data[image_name].index(prompts) prompt_data[image_name][idx] = prompt - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -166,9 +153,7 @@ def delete_prompt(dataset, image_name, prompts): return prompt_data[image_name].remove(prompts) - prompt_path = ( - str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" - ) + prompt_path = str(shark_root) + "/dataset/" + dataset + "/metadata.jsonl" # write prompt jsonlines file with open(prompt_path, "w") as f: for key, value in prompt_data.items(): @@ -231,9 +216,7 @@ def finish_annotation(dataset): # upload prompt and remove local data dataset_path = str(shark_root) + "/dataset/" + dataset dataset_gs_path = args.gs_url + "/" + dataset + "/" - os.system( - f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"' - ) + os.system(f'gsutil cp "{dataset_path}/metadata.jsonl" "{dataset_gs_path}"') os.system(f'rm -rf "{dataset_path}"') return gr.Dropdown.update(value=None) diff --git a/process_skipfiles.py b/process_skipfiles.py index a846159451..339c7ebec6 100644 --- a/process_skipfiles.py +++ b/process_skipfiles.py @@ -8,8 +8,7 @@ # Temporary workaround for transformers/__init__.py. path_to_transformers_hook = Path( - get_python_lib() - + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" + get_python_lib() + "/_pyinstaller_hooks_contrib/hooks/stdhooks/hook-transformers.py" ) if path_to_transformers_hook.is_file(): pass @@ -59,9 +58,7 @@ # For getting around timm's packaging. # Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505 -path_to_timm_activations = Path( - get_python_lib() + "/timm/layers/activations_jit.py" -) +path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py") for line in fileinput.input(path_to_timm_activations, inplace=True): if "@torch.jit.script" in line: print("@torch.jit._script_if_tracing", end="\n") diff --git a/pyproject.toml b/pyproject.toml index 22e0210c50..876df2f8bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,14 +5,25 @@ requires = [ "packaging", "numpy>=1.22.4", - "torch-mlir>=20230620.875", "iree-compiler>=20221022.190", "iree-runtime>=20221022.190", ] build-backend = "setuptools.build_meta" [tool.black] -line-length = 79 include = '\.pyi?$' -exclude = "apps/language_models/scripts/vicuna.py" -extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py" +exclude = ''' +( + /( + | apps/stable_diffusion + | apps/language_models + | shark + | benchmarks + | tank + | build + | generated_imgs + | shark.venv + )/ + | setup.py +) +''' diff --git a/pytest.ini b/pytest.ini index 11f57888b2..3857248785 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,3 +1,3 @@ [pytest] addopts = --verbose -s -p no:warnings -norecursedirs = inference tank/tflite examples benchmarks shark +norecursedirs = inference tank/tflite examples benchmarks shark apps/shark_studio diff --git a/requirements.txt b/requirements.txt index ff649a4468..3f7e719e67 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,13 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html +-f https://openxla.github.io/iree/pip-release-links.html --pre setuptools wheel +shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main +turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine#egg=turbine-models&subdirectory=python/turbine_models + # SHARK Runner tqdm @@ -17,11 +21,7 @@ pytest-forked Pillow parameterized -#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main # Add transformers, diffusers and scipy since it most commonly used -tokenizers==0.13.3 -transformers -diffusers #accelerate is now required for diffusers import from ckpt. accelerate scipy @@ -49,9 +49,6 @@ pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions pefile pyinstaller -# vicuna quantization -brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea - # For quantized GPTQ models optimum auto_gptq diff --git a/rest_api_tests/api_test.py b/rest_api_tests/api_test.py index 7a4cf042c2..f3c0b0e170 100644 --- a/rest_api_tests/api_test.py +++ b/rest_api_tests/api_test.py @@ -44,14 +44,10 @@ def upscaler_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[upscaler] response from server was : {res.status_code} {res.reason}" - ) + print(f"[upscaler] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def img2img_test(verbose=False): @@ -96,14 +92,10 @@ def img2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[img2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[img2img] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") # NOTE Uncomment below to save the picture @@ -133,13 +125,9 @@ def inpainting_test(verbose=False): image_path = r"./rest_api_tests/dog.png" img_file = open(image_path, "rb") - image = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + image = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() img_file = open(image_path, "rb") - mask = ( - "data:image/png;base64," + base64.b64encode(img_file.read()).decode() - ) + mask = "data:image/png;base64," + base64.b64encode(img_file.read()).decode() url = "http://127.0.0.1:8080/sdapi/v1/inpaint" @@ -166,14 +154,10 @@ def inpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[inpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[inpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def outpainting_test(verbose=False): @@ -223,14 +207,10 @@ def outpainting_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[outpaint] response from server was : {res.status_code} {res.reason}" - ) + print(f"[outpaint] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def txt2img_test(verbose=False): @@ -262,14 +242,10 @@ def txt2img_test(verbose=False): res = requests.post(url=url, json=data, headers=headers, timeout=1000) - print( - f"[txt2img] response from server was : {res.status_code} {res.reason}" - ) + print(f"[txt2img] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: - print( - f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n" - ) + print(f"\n{res.json()['info'] if res.status_code == 200 else res.content}\n") def sd_models_test(verbose=False): @@ -283,9 +259,7 @@ def sd_models_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_models] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_models] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -302,9 +276,7 @@ def sd_samplers_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[sd_samplers] response from server was : {res.status_code} {res.reason}" - ) + print(f"[sd_samplers] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -321,9 +293,7 @@ def options_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[options] response from server was : {res.status_code} {res.reason}" - ) + print(f"[options] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") @@ -340,9 +310,7 @@ def cmd_flags_test(verbose=False): res = requests.get(url=url, headers=headers, timeout=1000) - print( - f"[cmd-flags] response from server was : {res.status_code} {res.reason}" - ) + print(f"[cmd-flags] response from server was : {res.status_code} {res.reason}") if verbose or res.status_code != 200: print(f"\n{res.json() if res.status_code == 200 else res.content}\n") diff --git a/setup.py b/setup.py index c387fe9add..061873e7a8 100644 --- a/setup.py +++ b/setup.py @@ -9,11 +9,6 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5" backend_deps = [] -if "NO_BACKEND" in os.environ.keys(): - backend_deps = [ - "iree-compiler>=20221022.190", - "iree-runtime>=20221022.190", - ] setup( name="nodai-SHARK", @@ -39,7 +34,5 @@ install_requires=[ "numpy", "PyYAML", - "torch-mlir", ] - + backend_deps, ) diff --git a/shark/iree_utils/compile_utils.py b/shark/iree_utils/compile_utils.py index 6cfe369426..bae1908e1c 100644 --- a/shark/iree_utils/compile_utils.py +++ b/shark/iree_utils/compile_utils.py @@ -305,6 +305,7 @@ def compile_module_to_flatbuffer( model_name="None", debug=False, compile_str=False, + write_to=None, ): # Setup Compile arguments wrt to frontends. input_type = "auto" @@ -342,12 +343,24 @@ def compile_module_to_flatbuffer( extra_args=args, ) + if write_to is not None: + with open(write_to, "wb") as f: + f.write(flatbuffer_blob) + return None + return flatbuffer_blob def get_iree_module( - flatbuffer_blob, device, device_idx=None, rt_flags: list = [] + flatbuffer_blob, + device, + device_idx=None, + rt_flags: list = [], + external_weight_file=None, ): + if external_weight_file is not None: + index = ireert.ParameterIndex() + index.load(external_weight_file) # Returns the compiled module and the configs. for flag in rt_flags: ireert.flags.parse_flag(flag) @@ -369,7 +382,10 @@ def get_iree_module( vm_module = ireert.VmModule.from_buffer( config.vm_instance, flatbuffer_blob, warn_if_copy=False ) - ctx = ireert.SystemContext(config=config) + modules = [] + if external_weight_file is not None: + modules.append(index.create_provider(scope="model")) + ctx = ireert.SystemContext(vm_modules=modules, config=config) ctx.add_vm_module(vm_module) ModuleCompiled = getattr(ctx.modules, vm_module.name) return ModuleCompiled, config @@ -380,6 +396,7 @@ def load_vmfb_using_mmap( device: str, device_idx: int = None, rt_flags: list = [], + external_weight_file: str = None, ): print(f"Loading module {flatbuffer_blob_or_path}...") if "task" in device: @@ -440,17 +457,28 @@ def load_vmfb_using_mmap( mmaped_vmfb = ireert.VmModule.mmap( config.vm_instance, flatbuffer_blob_or_path ) + vm_modules = [] + if external_weight_file is not None: + index = ireert.ParameterIndex() + index.load(external_weight_file) + param_module = ireert.create_io_parameters_module( + config.vm_instance, index.create_provider(scope="model") + ) + vm_modules.append(param_module) + vm_modules.append(mmaped_vmfb) + vm_modules.append( + ireert.create_hal_module(config.vm_instance, config.device) + ) dl.log(f"mmap {flatbuffer_blob_or_path}") - ctx = ireert.SystemContext(config=config) - for flag in shark_args.additional_runtime_args: - ireert.flags.parse_flags(flag) - dl.log(f"ireert.SystemContext created") if "vulkan" in device: # Vulkan pipeline creation consumes significant amount of time. print( "\tCompiling Vulkan shaders. This may take a few minutes." ) - ctx.add_vm_module(mmaped_vmfb) + ctx = ireert.SystemContext(config=config, vm_modules=vm_modules) + dl.log(f"ireert.SystemContext created") + for flag in shark_args.additional_runtime_args: + ireert.flags.parse_flags(flag) dl.log(f"module initialized") mmaped_vmfb = getattr(ctx.modules, mmaped_vmfb.name) else: @@ -475,6 +503,8 @@ def get_iree_compiled_module( mmap: bool = False, debug: bool = False, compile_str: bool = False, + external_weight_file: str = None, + write_to: bool = None, ): """Given a module returns the compiled .vmfb and configs""" flatbuffer_blob = compile_module_to_flatbuffer( @@ -485,6 +515,7 @@ def get_iree_compiled_module( extra_args=extra_args, debug=debug, compile_str=compile_str, + write_to=write_to, ) temp_file_to_unlink = None # TODO: Currently mmap=True control flow path has been switched off for mmap. @@ -492,8 +523,14 @@ def get_iree_compiled_module( # we're setting delete=False when creating NamedTemporaryFile. That's why # I'm getting hold of the name of the temporary file in `temp_file_to_unlink`. if mmap: + if write_to is not None: + flatbuffer_blob = write_to vmfb, config, temp_file_to_unlink = load_vmfb_using_mmap( - flatbuffer_blob, device, device_idx, rt_flags + flatbuffer_blob, + device, + device_idx, + rt_flags, + external_weight_file=external_weight_file, ) else: vmfb, config = get_iree_module( @@ -501,6 +538,7 @@ def get_iree_compiled_module( device, device_idx=device_idx, rt_flags=rt_flags, + external_weight_file=external_weight_file, ) ret_params = { "vmfb": vmfb,