Remove some code duplication #56

Open
wants to merge 1 commit into base: main
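In outline: every predictor's `predict()` previously unpacked `imgs, np_imgs` from `shared_predict` and then ran its own `postprocess` call, and the dev-flavored predictors each repeated the img2img/fp8 fallback check. This PR moves both into `shared_predict`, which now takes the three output options as keyword-only arguments and returns the finished paths directly. Below is a minimal runnable sketch of the consolidated flow; `PredictorSketch` and the stub method bodies are hypothetical stand-ins, only the control flow mirrors the diff:

```python
# Illustrative stub of the refactor, not the real predict.py.
from typing import List, Optional


class PredictorSketch:
    disable_fp8 = False

    def fp8_predict(self, **kwargs):
        # stand-in for the quantized fp8 path
        return ["img"], ["np_img"]

    def base_predict(self, **kwargs):
        # stand-in for the bf16 path (supports img2img/inpainting)
        return ["img"], ["np_img"]

    def postprocess(self, imgs, disable_safety_checker, output_format,
                    output_quality, np_images=None) -> List[str]:
        # stand-in for safety checking and encoding to output files
        return imgs

    def shared_predict(
        self,
        go_fast: bool,
        prompt: str,
        *,  # everything below must be passed by keyword
        disable_safety_checker: bool,
        output_format: str,
        output_quality: int,
        image: Optional[str] = None,
        mask: Optional[str] = None,
    ) -> List[str]:
        # fp8 fallback check, formerly duplicated in each dev predict()
        if image and go_fast:
            print("img2img (or inpainting) not supported with fp8; running bf16")
            go_fast = False
        if go_fast and not self.disable_fp8:
            imgs, np_imgs = self.fp8_predict(prompt=prompt)
        else:
            if self.disable_fp8:
                print("running bf16 model, fp8 disabled")
            imgs, np_imgs = self.base_predict(prompt=prompt, image=image, mask=mask)
        # postprocessing, formerly duplicated in each predict()
        return self.postprocess(
            imgs, disable_safety_checker, output_format, output_quality,
            np_images=np_imgs,
        )

    def predict(self, prompt: str, go_fast: bool = True) -> List[str]:
        # each caller now just forwards the output options and returns the result
        return self.shared_predict(
            go_fast,
            prompt,
            disable_safety_checker=False,
            output_format="webp",
            output_quality=80,
        )


if __name__ == "__main__":
    print(PredictorSketch().predict("a photo of a fox"))  # ['img']
```

One behavioral tweak rides along in `HotswapPredictor`: `handle_loras` now runs after the fp8 check may have flipped `go_fast` to `False`, so LoRA weights are loaded for the backend that will actually run.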
132 changes: 59 additions & 73 deletions predict.py
@@ -615,6 +615,10 @@ def shared_predict(
         prompt: str,
         num_outputs: int,
         num_inference_steps: int,
+        *,
+        disable_safety_checker: bool,
+        output_format: str,
+        output_quality: int,
         guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
         image: Path = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
@@ -623,8 +627,13 @@ def shared_predict(
         height: int = 1024,
         mask: Path = None,  # inpainting
     ):
+        if image and go_fast:
+            print(
+                "img2img (or inpainting) not supported with fp8 quantization; running with bf16"
+            )
+            go_fast = False
         if go_fast and not self.disable_fp8:
-            return self.fp8_predict(
+            imgs, np_imgs = self.fp8_predict(
                 prompt=prompt,
                 num_outputs=num_outputs,
                 num_inference_steps=num_inference_steps,
@@ -635,19 +644,28 @@ def shared_predict(
                 width=width,
                 height=height,
             )
-        if self.disable_fp8:
-            print("running bf16 model, fp8 disabled")
-        return self.base_predict(
-            prompt=prompt,
-            num_outputs=num_outputs,
-            num_inference_steps=num_inference_steps,
-            guidance=guidance,
-            image=image,
-            prompt_strength=prompt_strength,
-            seed=seed,
-            width=width,
-            height=height,
-            mask=mask,
-        )
+        else:
+            if self.disable_fp8:
+                print("running bf16 model, fp8 disabled")
+            imgs, np_imgs = self.base_predict(
+                prompt=prompt,
+                num_outputs=num_outputs,
+                num_inference_steps=num_inference_steps,
+                guidance=guidance,
+                image=image,
+                prompt_strength=prompt_strength,
+                seed=seed,
+                width=width,
+                height=height,
+                mask=mask,
+            )
+
+        return self.postprocess(
+            imgs,
+            disable_safety_checker,
+            output_format,
+            output_quality,
+            np_images=np_imgs,
+        )
 
 
@@ -674,24 +692,19 @@ def predict(
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
             num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             seed=seed,
             width=width,
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
 
 
 class DevPredictor(Predictor):
     def setup(self) -> None:
@@ -728,15 +741,15 @@ def predict(
         go_fast: bool = SHARED_INPUTS.go_fast,
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
-        if image and go_fast:
-            print("img2img not supported with fp8 quantization; running with bf16")
-            go_fast = False
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance,
             image=image,
             prompt_strength=prompt_strength,
@@ -745,14 +758,6 @@
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
 
 
 class SchnellLoraPredictor(Predictor):
     def setup(self) -> None:
@@ -782,24 +787,19 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)
 
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
             num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             seed=seed,
             width=width,
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
 
 
 class DevLoraPredictor(Predictor):
     def setup(self, t5=None, clip=None, ae=None) -> None:
@@ -839,18 +839,17 @@ def predict(
         lora_scale: float = SHARED_INPUTS.lora_scale,
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
-        if image and go_fast:
-            print("img2img not supported with fp8 quantization; running with bf16")
-            go_fast = False
-
         self.handle_loras(go_fast, lora_weights, lora_scale)
 
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance,
             image=image,
             prompt_strength=prompt_strength,
@@ -859,14 +858,6 @@
             height=height,
         )
-
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
 
 
 class HotswapPredictor(BasePredictor):
     def setup(self) -> None:
@@ -965,21 +956,24 @@ def predict(
         else:
             width, height = model.preprocess(aspect_ratio, megapixels=megapixels)
 
-        model.handle_loras(
-            go_fast, replicate_weights, lora_scale, extra_lora, extra_lora_scale
-        )
-
         if image and go_fast:
             print(
                 "Img2img and inpainting not supported with fast fp8 inference; will run in bf16"
             )
             go_fast = False
 
-        imgs, np_imgs = model.shared_predict(
+        model.handle_loras(
+            go_fast, replicate_weights, lora_scale, extra_lora, extra_lora_scale
+        )
+
+        return model.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance_scale,
             image=image,
             prompt_strength=prompt_strength,
@@ -989,14 +983,6 @@
             mask=mask,
         )
-
-        return model.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
 
 
 class TestPredictor(Predictor):
     def setup(self) -> None:
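A small design note on the new signature: the bare `*` in `shared_predict` makes `disable_safety_checker`, `output_format`, and `output_quality` keyword-only, so none of the existing positional call sites can silently bind to them; a positional mistake fails loudly instead. A quick illustration with a hypothetical function:

```python
def f(prompt, *, output_format):
    return prompt, output_format

print(f("hi", output_format="webp"))  # ('hi', 'webp')
try:
    f("hi", "webp")  # positional binding is rejected
except TypeError as e:
    print(e)  # f() takes 1 positional argument but 2 were given
```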