Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed May 23, 2024
1 parent 2cfb570 commit 353b930
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 28 deletions.
15 changes: 13 additions & 2 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ def __init__(
if not os.path.exists(self.weights_path):
os.mkdir(self.weights_path)

decomp_attn = True
attn_spec = None
if triple in ["gfx940", "gfx942", "gfx90a"]:
decomp_attn = False
attn_spec = "mfma"
elif triple in ["gfx1100", "gfx1103"]:
decomp_attn = False
attn_spec = "wmma"
elif target_backend == "llvm-cpu":
decomp_attn = False

self.sd_pipe = self.turbine_pipe(
hf_model_name=base_model_id,
scheduler_id=scheduler,
Expand All @@ -124,8 +135,8 @@ def __init__(
device=target_backend,
iree_target_triple=triple,
ireec_flags=EMPTY_FLAGS,
attn_spec=None,
decomp_attn=True if "gfx9" not in triple else False,
attn_spec=attn_spec,
decomp_attn=decomp_attn,
pipeline_dir=self.pipeline_dir,
external_weights_dir=self.weights_path,
external_weights=external_weights,
Expand Down
48 changes: 24 additions & 24 deletions apps/shark_studio/modules/schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,30 @@ def get_schedulers(model_id):
schedulers["DPMSolverMultistep++"] = DPMSolverMultistepScheduler.from_pretrained(
model_id, subfolder="scheduler", algorithm_type="dpmsolver++"
)
schedulers["DPMSolverMultistepKarras"] = (
DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
)
schedulers["DPMSolverMultistepKarras++"] = (
DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
use_karras_sigmas=True,
)
schedulers[
"DPMSolverMultistepKarras++"
] = DPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler",
algorithm_type="dpmsolver++",
use_karras_sigmas=True,
)
schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["EulerAncestralDiscrete"] = (
EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"EulerAncestralDiscrete"
] = EulerAncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained(
model_id,
Expand All @@ -83,11 +83,11 @@ def get_schedulers(model_id):
model_id,
subfolder="scheduler",
)
schedulers["KDPM2AncestralDiscrete"] = (
KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers[
"KDPM2AncestralDiscrete"
] = KDPM2AncestralDiscreteScheduler.from_pretrained(
model_id,
subfolder="scheduler",
)
schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained(
model_id,
Expand Down
2 changes: 1 addition & 1 deletion apps/shark_studio/web/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def write_default_sd_config(path):


def safe_name(name):
return name.replace("/", "_").replace("-", "_")
return name.split("/")[-1].replace("-", "_")


def get_path_stem(path):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ parameterized
#accelerate is now required for diffusers import from ckpt.
accelerate
ftfy
gradio==4.19.2
gradio==4.29.0
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
Expand Down

0 comments on commit 353b930

Please sign in to comment.