Skip to content

Commit

Permalink
add port and adapter for worker
Browse files Browse the repository at this point in the history
 - add interface for generative ai worker
 - rename sd repository and move it to integrations
 - add genai backend in settings to choose between celery or ray
  • Loading branch information
curibe committed Aug 11, 2023
1 parent 64fe0b6 commit fa533ed
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 16 deletions.
9 changes: 7 additions & 2 deletions morpheus-server/app/api/sdiffusion_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from morpheus_data.database import get_db
from morpheus_data.models.schemas import MagicPrompt, Prompt, PromptControlNet

from app.config import get_genai_backend
from app.error.error import ImageNotProvidedError, ModelNotFoundError, UserNotFoundError
from app.integrations.firebase import get_user
from app.models.schemas import (
Expand All @@ -14,7 +15,8 @@
from app.services.sdiffusion_services import StableDiffusionService

router = APIRouter()
sd_services = StableDiffusionService()
genai_generator = get_genai_backend()
sd_services = StableDiffusionService(genai_generator=genai_generator)


@router.post(
Expand Down Expand Up @@ -143,7 +145,10 @@ async def generate_inpaint_from_prompt_and_image_and_mask(
response_description="Perform upscaling on an image with a prompt.",
)
async def generate_upscale_from_prompt_and_image(
prompt: Prompt = Depends(), image: UploadFile = File(...), db=Depends(get_db), user=Depends(get_user)
prompt: Prompt = Depends(),
image: UploadFile = File(...),
db=Depends(get_db),
user=Depends(get_user),
):
try:
image = await image.read()
Expand Down
42 changes: 39 additions & 3 deletions morpheus-server/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ class EnvironmentEnum(str, Enum):
prod = "prod"


class GenAIBackendEnum(str, Enum):
celery = "celery"
ray = "ray"


class Settings(SettingsData):
environment: EnvironmentEnum = EnvironmentEnum.local

Expand All @@ -28,6 +33,8 @@ class Settings(SettingsData):
enable_float32: bool = False
max_num_images: int = 4

genai_backend: str = GenAIBackendEnum.celery

celery_broker_url: str = "redis://redis:6379/0"
celery_result_backend: str = "redis://redis:6379/0"

Expand Down Expand Up @@ -64,17 +71,46 @@ def read_available_samplers(file: str):
samplers = read_available_samplers("config/sd-schedulers.yaml")

file_handlers = {
"S3": {"module": "morpheus_data.repository.files.s3_files_repository", "handler": "S3ImagesRepository"}
"S3": {
"module": "morpheus_data.repository.files.s3_files_repository",
"handler": "S3ImagesRepository",
}
}

backend_handlers = {
"celery": {
"module": "app.integrations.genai_engine.sdiffusion_celery",
"handler": "GenAIStableDiffusionCelery",
}
}


@lru_cache()
def get_file_handlers():
settings = get_settings()
try:
module_import = importlib.import_module(file_handlers[settings.bucket_type]["module"])
file_handler = getattr(module_import, file_handlers[settings.bucket_type]["handler"])
module_import = importlib.import_module(
file_handlers[settings.bucket_type]["module"]
)
file_handler = getattr(
module_import, file_handlers[settings.bucket_type]["handler"]
)
return file_handler()
except Exception as e:
print("Error getting file handler", e)
return None


def get_genai_backend():
settings = get_settings()
try:
module_import = importlib.import_module(
backend_handlers[settings.genai_backend]["module"]
)
backend = getattr(
module_import, backend_handlers[settings.genai_backend]["handler"]
)
return backend()
except Exception as e:
print("Error getting genai backend", e)
return None
File renamed without changes.
42 changes: 42 additions & 0 deletions morpheus-server/app/integrations/genai_engine/genai_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod

from PIL import Image

from morpheus_data.models.schemas import MagicPrompt, Prompt, PromptControlNet


class GenAIInterface(ABC):
@staticmethod
@abstractmethod
def generate_text2img_images(prompt: Prompt) -> str:
raise NotImplementedError("This method is not implemented")

@staticmethod
@abstractmethod
def generate_img2img_images(prompt: Prompt, image: Image) -> str:
raise NotImplementedError("This method is not implemented")

@staticmethod
@abstractmethod
def generate_controlnet_images(prompt: PromptControlNet, image: Image) -> str:
raise NotImplementedError("This method is not implemented")

@staticmethod
@abstractmethod
def generate_pix2pix_images(prompt: Prompt, image: Image) -> str:
raise NotImplementedError("This method is not implemented")

@staticmethod
@abstractmethod
def generate_inpainting_images(prompt: Prompt, image: Image, mask: Image) -> str:
raise NotImplementedError("This method is not implemented")

@staticmethod
@abstractmethod
def generate_upscaling_images(prompt: Prompt, image: Image) -> str:
raise NotImplementedError("This method is not implemented")

@staticmethod
@abstractmethod
def generate_magicprompt(prompt: MagicPrompt) -> str:
raise NotImplementedError("This method is not implemented")
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
generate_stable_diffusion_text2img_output_task,
generate_stable_diffusion_upscale_output_task,
)
from app.integrations.genai_engine.genai_interface import GenAIInterface


class StableDiffusionRepository:
class GenAIStableDiffusionCelery(GenAIInterface):
@staticmethod
def generate_text2img_images(prompt: Prompt) -> str:
logger.info(f" Running Stable Diffusion Text2Img process with prompt: {prompt}")
Expand Down
20 changes: 10 additions & 10 deletions morpheus-server/app/services/sdiffusion_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,58 +10,58 @@

from app.config import get_settings
from app.error.error import ImageNotProvidedError, ModelNotFoundError, UserNotFoundError
from app.repository.sdiffusion_repository import StableDiffusionRepository
from app.integrations.genai_engine.genai_interface import GenAIInterface


class StableDiffusionService:
def __init__(self):
self.sd_repository = StableDiffusionRepository()
def __init__(self, genai_generator: GenAIInterface):
self.sd_generator = genai_generator
self.model_repository = ModelRepository()
self.user_repository = UserRepository()
self.settings = get_settings()

def generate_text2img_images(self, db: Session, prompt: Prompt, email: str) -> str:
self.validate_request(db=db, model=prompt.model, email=email)
prompt.model = f"{self.settings.model_parent_path}{prompt.model}"
return self.sd_repository.generate_text2img_images(prompt)
return self.sd_generator.generate_text2img_images(prompt)

def generate_img2img_images(self, db: Session, prompt: Prompt, image: Any, email: str) -> str:
self.validate_request(db=db, model=prompt.model, email=email)
image = self.validate_and_clean_image(image=image, width=prompt.width)
prompt.model = f"{self.settings.model_parent_path}{prompt.model}"
return self.sd_repository.generate_img2img_images(prompt, image)
return self.sd_generator.generate_img2img_images(prompt, image)

def generate_controlnet_images(self, db: Session, prompt: PromptControlNet, image: Any, email: str) -> str:
self.validate_request(db=db, model=prompt.model, email=email)
image = self.validate_and_clean_image(image=image, width=prompt.width)
prompt.model = f"{self.settings.model_parent_path}{prompt.model}"
return self.sd_repository.generate_controlnet_images(prompt, image)
return self.sd_generator.generate_controlnet_images(prompt, image)

def generate_pix2pix_images(self, db: Session, prompt: Prompt, image: Any, email: str) -> str:
self.validate_request(db=db, model=prompt.model, email=email)
image = self.validate_and_clean_image(image=image, width=prompt.width)
prompt.model = f"{self.settings.model_parent_path}{prompt.model}"
return self.sd_repository.generate_pix2pix_images(prompt, image)
return self.sd_generator.generate_pix2pix_images(prompt, image)

def generate_inpainting_images(self, db: Session, prompt: Prompt, image: Any, mask: Any, email: str) -> str:
self.validate_request(db=db, model=prompt.model, email=email)
image = self.validate_and_clean_image(image=image, width=512, height=512)
mask = self.validate_and_clean_image(image=mask, width=512, height=512)
prompt.model = f"{self.settings.model_parent_path}{prompt.model}"
return self.sd_repository.generate_inpainting_images(prompt, image, mask)
return self.sd_generator.generate_inpainting_images(prompt, image, mask)

def generate_upscaling_images(self, db: Session, prompt: Prompt, image: Any, email: str) -> str:
self.validate_request(db=db, model=prompt.model, email=email)
image = self.validate_and_clean_image(image=image)
prompt.width = image.width
prompt.height = image.height
prompt.model = f"{self.settings.model_parent_path}{prompt.model}"
return self.sd_repository.generate_upscaling_images(prompt, image)
return self.sd_generator.generate_upscaling_images(prompt, image)

def generate_magicprompt(self, db: Session, prompt: MagicPrompt, email: str) -> str:
self.validate_magicprompt_request(db=db, email=email)
# prompt.config.model = f"{self.settings.model_parent_path}{prompt.config.model}"
return self.sd_repository.generate_magicprompt(prompt)
return self.sd_generator.generate_magicprompt(prompt)

def validate_request(self, db: Session, model: str, email: str) -> None:
db_user = self.user_repository.get_user_by_email(db=db, email=email)
Expand Down
3 changes: 3 additions & 0 deletions morpheus-server/secrets.env.dist
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ IMAGES_TEMP_BUCKET=
# Allowed origins for CORS
ALLOWED_ORIGINS="http://localhost:3000"

# Specify the Generative AI backend to use
GENAI_BACKEND=celery

# Environment
ENVIRONMENT=local

Expand Down

0 comments on commit fa533ed

Please sign in to comment.