From 2cd8d493dc2485cc941cab1d0582a516c8f7982f Mon Sep 17 00:00:00 2001
From: Yorick van Pelt
Date: Fri, 29 Nov 2024 20:03:57 +0100
Subject: [PATCH] Remove more code duplication

---
 predict.py | 132 ++++++++++++++++++++++++-----------------------------
 1 file changed, 59 insertions(+), 73 deletions(-)

diff --git a/predict.py b/predict.py
index d74f55d..d6f1270 100644
--- a/predict.py
+++ b/predict.py
@@ -615,6 +615,10 @@ def shared_predict(
         prompt: str,
         num_outputs: int,
         num_inference_steps: int,
+        *,
+        disable_safety_checker: bool,
+        output_format: str,
+        output_quality: int,
         guidance: float = 3.5,  # schnell ignores guidance within the model, fine to have default
         image: Path = None,  # img2img for flux-dev
         prompt_strength: float = 0.8,
@@ -623,8 +627,13 @@ def shared_predict(
         height: int = 1024,
         mask: Path = None,  # inpainting
     ):
+        if image and go_fast:
+            print(
+                "img2img (or inpainting) not supported with fp8 quantization; running with bf16"
+            )
+            go_fast = False
         if go_fast and not self.disable_fp8:
-            return self.fp8_predict(
+            imgs, np_imgs = self.fp8_predict(
                 prompt=prompt,
                 num_outputs=num_outputs,
                 num_inference_steps=num_inference_steps,
@@ -635,19 +644,28 @@ def shared_predict(
                 guidance=guidance,
                 image=image,
                 prompt_strength=prompt_strength,
                 seed=seed,
                 width=width,
                 height=height,
             )
-        if self.disable_fp8:
-            print("running bf16 model, fp8 disabled")
-        return self.base_predict(
-            prompt=prompt,
-            num_outputs=num_outputs,
-            num_inference_steps=num_inference_steps,
-            guidance=guidance,
-            image=image,
-            prompt_strength=prompt_strength,
-            seed=seed,
-            width=width,
-            height=height,
-            mask=mask,
+        else:
+            if self.disable_fp8:
+                print("running bf16 model, fp8 disabled")
+            imgs, np_imgs = self.base_predict(
+                prompt=prompt,
+                num_outputs=num_outputs,
+                num_inference_steps=num_inference_steps,
+                guidance=guidance,
+                image=image,
+                prompt_strength=prompt_strength,
+                seed=seed,
+                width=width,
+                height=height,
+                mask=mask,
+            )
+
+        return self.postprocess(
+            imgs,
+            disable_safety_checker,
+            output_format,
+            output_quality,
+            np_images=np_imgs,
         )
@@ -674,24 +692,19 @@ def predict(
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
             num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             seed=seed,
             width=width,
             height=height,
         )
 
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
-
 
 class DevPredictor(Predictor):
     def setup(self) -> None:
@@ -728,15 +741,15 @@ def predict(
         go_fast: bool = SHARED_INPUTS.go_fast,
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
-        if image and go_fast:
-            print("img2img not supported with fp8 quantization; running with bf16")
-            go_fast = False
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance,
             image=image,
             prompt_strength=prompt_strength,
@@ -745,14 +758,6 @@ def predict(
             seed=seed,
             width=width,
             height=height,
         )
 
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
-
 
 class SchnellLoraPredictor(Predictor):
     def setup(self) -> None:
@@ -782,24 +787,19 @@ def predict(
         self.handle_loras(go_fast, lora_weights, lora_scale)
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
             num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             seed=seed,
             width=width,
             height=height,
         )
 
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
-
 
 class DevLoraPredictor(Predictor):
     def setup(self, t5=None, clip=None, ae=None) -> None:
@@ -839,18 +839,17 @@ def predict(
         lora_scale: float = SHARED_INPUTS.lora_scale,
         megapixels: str = SHARED_INPUTS.megapixels,
     ) -> List[Path]:
-        if image and go_fast:
-            print("img2img not supported with fp8 quantization; running with bf16")
-            go_fast = False
-
         self.handle_loras(go_fast, lora_weights, lora_scale)
         width, height = self.preprocess(aspect_ratio, megapixels)
-        imgs, np_imgs = self.shared_predict(
+        return self.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance,
             image=image,
             prompt_strength=prompt_strength,
@@ -859,14 +858,6 @@ def predict(
             seed=seed,
             width=width,
             height=height,
         )
 
-        return self.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
-
 
 class HotswapPredictor(BasePredictor):
     def setup(self) -> None:
@@ -965,21 +956,24 @@ def predict(
         else:
             width, height = model.preprocess(aspect_ratio, megapixels=megapixels)
 
-        model.handle_loras(
-            go_fast, replicate_weights, lora_scale, extra_lora, extra_lora_scale
-        )
-
         if image and go_fast:
             print(
                 "Img2img and inpainting not supported with fast fp8 inference; will run in bf16"
             )
             go_fast = False
 
-        imgs, np_imgs = model.shared_predict(
+        model.handle_loras(
+            go_fast, replicate_weights, lora_scale, extra_lora, extra_lora_scale
+        )
+
+        return model.shared_predict(
             go_fast,
             prompt,
             num_outputs,
-            num_inference_steps,
+            num_inference_steps=num_inference_steps,
+            disable_safety_checker=disable_safety_checker,
+            output_format=output_format,
+            output_quality=output_quality,
             guidance=guidance_scale,
             image=image,
             prompt_strength=prompt_strength,
@@ -989,14 +983,6 @@ def predict(
             seed=seed,
             width=width,
             height=height,
             mask=mask,
         )
 
-        return model.postprocess(
-            imgs,
-            disable_safety_checker,
-            output_format,
-            output_quality,
-            np_images=np_imgs,
-        )
-
 
 class TestPredictor(Predictor):
     def setup(self) -> None:
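After this patch, every predict() body reduces to preprocessing plus a single
shared_predict call, and shared_predict owns the fp8/bf16 dispatch, the fp8
img2img fallback, and the one postprocess call. A minimal standalone sketch of
that resulting control flow, for reference: the method bodies are stubs, the
repeated keyword arguments are condensed into one dict here, and the
PredictorSketch name is a placeholder, not something from predict.py:

    from pathlib import Path
    from typing import List, Optional

    class PredictorSketch:
        """Illustrative only: stub bodies stand in for the real model code."""

        disable_fp8 = False

        def fp8_predict(self, **kwargs):
            return [], []  # stub: the real method returns (imgs, np_imgs)

        def base_predict(self, **kwargs):
            return [], []  # stub: the real method returns (imgs, np_imgs)

        def postprocess(self, imgs, disable_safety_checker, output_format,
                        output_quality, np_images) -> List[Path]:
            return []  # stub: safety-check, encode, and write output files

        def shared_predict(
            self,
            go_fast: bool,
            prompt: str,
            num_outputs: int,
            num_inference_steps: int,
            *,
            disable_safety_checker: bool,
            output_format: str,
            output_quality: int,
            guidance: float = 3.5,
            image: Optional[Path] = None,
            prompt_strength: float = 0.8,
            seed: Optional[int] = None,
            width: int = 1024,
            height: int = 1024,
            mask: Optional[Path] = None,
        ) -> List[Path]:
            # fp8 path cannot do img2img/inpainting, so fall back to bf16
            if image and go_fast:
                print("img2img (or inpainting) not supported with fp8 "
                      "quantization; running with bf16")
                go_fast = False

            # shared kwargs, condensed into one dict for this sketch
            common = dict(
                prompt=prompt, num_outputs=num_outputs,
                num_inference_steps=num_inference_steps, guidance=guidance,
                image=image, prompt_strength=prompt_strength, seed=seed,
                width=width, height=height,
            )
            if go_fast and not self.disable_fp8:
                imgs, np_imgs = self.fp8_predict(**common)
            else:
                if self.disable_fp8:
                    print("running bf16 model, fp8 disabled")
                imgs, np_imgs = self.base_predict(mask=mask, **common)

            # single exit point: postprocessing happens exactly once, here
            return self.postprocess(imgs, disable_safety_checker,
                                    output_format, output_quality,
                                    np_images=np_imgs)

Making the three postprocess parameters keyword-only in shared_predict means
any caller that was not updated fails loudly with a TypeError instead of
silently skipping safety checking or output encoding.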