diff --git a/apps/shark_studio/api/sd.py b/apps/shark_studio/api/sd.py index d70b60c7e7..e0534db5de 100644 --- a/apps/shark_studio/api/sd.py +++ b/apps/shark_studio/api/sd.py @@ -104,7 +104,7 @@ def __init__( self.base_model_id = base_model_id self.custom_vae = custom_vae self.is_sdxl = "xl" in self.base_model_id.lower() - self.is_custom = "custom" in self.base_model_id.lower() + self.is_custom = ".py" in self.base_model_id.lower() if self.is_custom: custom_module = load_script( os.path.join(get_checkpoints_path("scripts"), self.base_model_id), @@ -112,8 +112,7 @@ def __init__( ) self.turbine_pipe = custom_module.StudioPipeline self.model_map = custom_module.MODEL_MAP - - if self.is_sdxl: + elif self.is_sdxl: self.turbine_pipe = SharkSDXLPipeline self.model_map = EMPTY_SDXL_MAP else: