Skip to content

Commit

Permalink
Merge branch 'main' into ean-studio-sdxl
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored May 23, 2024
2 parents 72cde31 + fd07cae commit 17b98e7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 40 deletions.
4 changes: 4 additions & 0 deletions apps/shark_studio/api/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def format_out(results):
self.prev_token_len = token_len + len(history)

if format_out(token) == llm_model_map[self.hf_model_name]["stop_token"]:
if (
format_out(token)
== llm_model_map["meta-llama/Llama-2-7b-chat-hf"]["stop_token"]
):
break

for i in range(len(history)):
Expand Down
70 changes: 39 additions & 31 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,14 @@ def import_original(original_img, width, height):


def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")
ckpt_path = Path(
os.path.join(
cmd_opts.model_dir, "checkpoints", os.path.basename(str(base_model_id))
)
)
ckpt_path.mkdir(parents=True, exist_ok=True)

new_choices = get_checkpoints(ckpt_path) + get_checkpoints(model_type="checkpoints")

return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
Expand Down Expand Up @@ -579,21 +584,6 @@ def base_model_changed(base_model_id):
object_fit="fit",
preview=True,
)
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Row():
batch_count = gr.Slider(
1,
Expand Down Expand Up @@ -629,19 +619,18 @@ def base_model_changed(base_model_id):
stop_batch = gr.Button("Stop")
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
Path(get_configs_path()).mkdir(parents=True, exist_ok=True)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
label="SD Config",
elem_classes=["fill"],
value=view_json_file(default_config_file),
render=False,
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
Expand Down Expand Up @@ -704,11 +693,30 @@ def base_model_changed(base_model_id):
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Row(elem_classes=["fill"]):
sd_json.render()
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)
with gr.Tab(label="Log", id=103) as sd_tab_log:
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Tab(label="Automation", id=104) as sd_tab_automation:
pass

pull_kwargs = dict(
fn=pull_sd_configs,
Expand Down
12 changes: 4 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,25 +1,21 @@
-f https://download.pytorch.org/whl/nightly/cpu

--index-url https://download.pytorch.org/whl/nightly/cpu
-f https://iree.dev/pip-release-links.html
--pre

setuptools
wheel


torch>=2.3.0
shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@ean-unify-sd#subdirectory=models


# SHARK Runner
tqdm

# SHARK Downloader
google-cloud-storage

# Testing
pytest
pytest-xdist
pytest-forked
pytes
Pillow
parameterized

Expand Down
2 changes: 1 addition & 1 deletion setup_venv.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install --pre -r requirements.txt
pip install -r requirements.txt
pip install -e .

Write-Host "Source your venv with ./shark.venv/Scripts/activate"

0 comments on commit 17b98e7

Please sign in to comment.