diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index 699e9fe2a..44a05749d 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -3,7 +3,6 @@ import os.path import tempfile import typing -from enum import Enum import gooey_gui as gui import requests @@ -14,6 +13,7 @@ from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings from daras_ai_v2.azure_asr import azure_asr +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import ( raise_for_status, UserError, @@ -204,7 +204,7 @@ GHANA_NLP_MAXLEN = 500 -class AsrModels(Enum): +class AsrModels(GooeyEnum): whisper_large_v2 = "Whisper Large v2 (openai)" whisper_large_v3 = "Whisper Large v3 (openai)" whisper_hindi_large_v2 = "Whisper Hindi Large v2 (Bhashini)" @@ -277,7 +277,7 @@ class AsrOutputJson(typing_extensions.TypedDict): chunks: typing_extensions.NotRequired[list[AsrChunk]] -class AsrOutputFormat(Enum): +class AsrOutputFormat(GooeyEnum): text = "Text" json = "JSON" srt = "SRT" @@ -290,7 +290,7 @@ class TranslationModel(typing.NamedTuple): supports_auto_detect: bool = False -class TranslationModels(TranslationModel, Enum): +class TranslationModels(TranslationModel, GooeyEnum): google = TranslationModel( label="Google Translate", supports_glossary=True, diff --git a/daras_ai_v2/base.py b/daras_ai_v2/base.py index 70c294a92..0874ba64f 100644 --- a/daras_ai_v2/base.py +++ b/daras_ai_v2/base.py @@ -127,6 +127,7 @@ class BasePage: title: str workflow: Workflow slug_versions: list[str] + sdk_method_name: str sane_defaults: dict = {} @@ -144,6 +145,9 @@ class BasePage: ) class RequestModel(BaseModel): + class Config: + use_enum_values = True + functions: list[RecipeFunction] | None = Field( title="🧩 Developer Tools and Functions", ) @@ -317,6 +321,12 @@ def sentry_event_set_user(self, event, hint): } return event + @classmethod + def get_openapi_extra(cls) -> dict[str, typing.Any]: + return { + "x-fern-sdk-method-name": cls.sdk_method_name, + } + def refresh_state(self): sr = self.current_sr channel = self.realtime_channel_name(sr.run_id, sr.uid) diff --git a/daras_ai_v2/custom_enum.py b/daras_ai_v2/custom_enum.py index b9aacb843..270d4834c 100644 --- a/daras_ai_v2/custom_enum.py +++ b/daras_ai_v2/custom_enum.py @@ -1,12 +1,38 @@ +import functools import typing from enum import Enum import typing_extensions + +def cached_classmethod(func: typing.Callable): + """ + This cache is a hack to get around a bug where + dynamic Enums with the same name cause a crash + when generating the OpenAPI spec. + """ + + @functools.wraps(func) + def wrapper(cls): + if not hasattr(cls, "_cached_classmethod"): + cls._cached_classmethod = {} + if id(func) not in cls._cached_classmethod: + cls._cached_classmethod[id(func)] = func(cls) + + return cls._cached_classmethod[id(func)] + + return wrapper + + T = typing.TypeVar("T", bound="GooeyEnum") class GooeyEnum(Enum): + @property + def api_value(self): + # api_value is usually the name + return self.name + @classmethod def db_choices(cls): return [(e.db_value, e.label) for e in cls] @@ -20,12 +46,23 @@ def from_db(cls, db_value) -> typing_extensions.Self: @classmethod @property - def api_choices(cls): - return typing.Literal[tuple(e.name for e in cls)] + @cached_classmethod + def api_enum(cls): + """ + Enum that is useful as a type in API requests. + + Maps `api_value`->`api_value` (default: `name`->`name`) + for all values. + + The title (same as the Enum class name) will be + used as the new Enum's title. 
This will be passed + on to the OpenAPI schema and the generated SDK. + """ + return Enum(cls.__name__, {e.api_value: e.api_value for e in cls}) @classmethod - def from_api(cls, name: str) -> typing_extensions.Self: + def from_api(cls, api_value: str) -> typing_extensions.Self: for e in cls: - if e.name == name: + if e.api_value == api_value: return e - raise ValueError(f"Invalid {cls.__name__} {name=}") + raise ValueError(f"Invalid {cls.__name__} {api_value=}") diff --git a/daras_ai_v2/embedding_model.py b/daras_ai_v2/embedding_model.py index 9f5f3ae1b..5f2927bcc 100644 --- a/daras_ai_v2/embedding_model.py +++ b/daras_ai_v2/embedding_model.py @@ -1,7 +1,6 @@ import hashlib import io import typing -from enum import Enum from functools import partial import numpy as np @@ -13,6 +12,7 @@ from jinja2.lexer import whitespace_re from loguru import logger +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.language_model import get_openai_client from daras_ai_v2.redis_cache import ( @@ -25,7 +25,7 @@ class EmbeddingModel(typing.NamedTuple): label: str -class EmbeddingModels(Enum): +class EmbeddingModels(EmbeddingModel, GooeyEnum): openai_3_large = EmbeddingModel( model_id=("openai-text-embedding-3-large-prod-ca-1", "text-embedding-3-large"), label="Text Embedding 3 Large (OpenAI)", @@ -65,12 +65,8 @@ class EmbeddingModels(Enum): ) @property - def model_id(self) -> typing.Iterable[str] | str: - return self.value.model_id - - @property - def label(self) -> str: - return self.value.label + def db_value(self): + return self.name @classmethod def get(cls, key, default=None): diff --git a/daras_ai_v2/image_segmentation.py b/daras_ai_v2/image_segmentation.py index 099832979..004264afc 100644 --- a/daras_ai_v2/image_segmentation.py +++ b/daras_ai_v2/image_segmentation.py @@ -1,14 +1,13 @@ -from enum import Enum - import requests +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.gpu_server import ( call_celery_task_outfile, ) -class ImageSegmentationModels(Enum): +class ImageSegmentationModels(str, GooeyEnum): dis = "Dichotomous Image Segmentation" u2net = "U²-Net" diff --git a/daras_ai_v2/img_model_settings_widgets.py b/daras_ai_v2/img_model_settings_widgets.py index 7f62fb878..3b9705479 100644 --- a/daras_ai_v2/img_model_settings_widgets.py +++ b/daras_ai_v2/img_model_settings_widgets.py @@ -2,11 +2,10 @@ from daras_ai_v2.enum_selector_widget import enum_selector, enum_multiselect from daras_ai_v2.stable_diffusion import ( - Text2ImgModels, + TextToImageModels, InpaintingModels, - Img2ImgModels, + ImageToImageModels, ControlNetModels, - controlnet_model_explanations, Schedulers, ) @@ -37,10 +36,10 @@ def img_model_settings( negative_prompt_setting(selected_model) num_outputs_setting(selected_model) - if models_enum is not Img2ImgModels: + if models_enum is not ImageToImageModels: output_resolution_setting() - if models_enum is Text2ImgModels: + if models_enum is TextToImageModels: sd_2_upscaling_setting() col1, col2 = gui.columns(2) @@ -49,11 +48,11 @@ def img_model_settings( guidance_scale_setting(selected_model) with col2: - if models_enum is Img2ImgModels and not gui.session_state.get( + if models_enum is ImageToImageModels and not gui.session_state.get( "selected_controlnet_model" ): prompt_strength_setting(selected_model) - if selected_model == Img2ImgModels.instruct_pix2pix.name: + if selected_model == ImageToImageModels.instruct_pix2pix.name: 
instruct_pix2pix_settings() if show_scheduler: @@ -73,10 +72,10 @@ def model_selector( high_explanation: str = "At {high} the control nets will be applied tightly to the prompted visual, possibly overriding the prompt", ): controlnet_unsupported_models = [ - Img2ImgModels.instruct_pix2pix.name, - Img2ImgModels.dall_e.name, - Img2ImgModels.jack_qiao.name, - Img2ImgModels.sd_2.name, + ImageToImageModels.instruct_pix2pix.name, + ImageToImageModels.dall_e.name, + ImageToImageModels.jack_qiao.name, + ImageToImageModels.sd_2.name, ] col1, col2 = gui.columns(2) with col1: @@ -96,12 +95,12 @@ def model_selector( """ ) if ( - models_enum is Img2ImgModels + models_enum is ImageToImageModels and gui.session_state.get("selected_model") in controlnet_unsupported_models ): if "selected_controlnet_model" in gui.session_state: gui.session_state["selected_controlnet_model"] = None - elif models_enum is Img2ImgModels: + elif models_enum is ImageToImageModels: enum_multiselect( ControlNetModels, label=controlnet_explanation, @@ -130,9 +129,7 @@ def controlnet_settings( if not models: return - if extra_explanations is None: - extra_explanations = {} - explanations = controlnet_model_explanations | extra_explanations + extra_explanations = extra_explanations or {} state_values = gui.session_state.get("controlnet_conditioning_scale", []) new_values = [] @@ -157,7 +154,9 @@ def controlnet_settings( pass new_values.append( controlnet_weight_setting( - selected_controlnet_model=model, explanations=explanations, key=key + selected_controlnet_model=model, + extra_explanations=extra_explanations, + key=key, ), ) gui.session_state["controlnet_conditioning_scale"] = new_values @@ -166,13 +165,13 @@ def controlnet_settings( def controlnet_weight_setting( *, selected_controlnet_model: str, - explanations: dict[ControlNetModels, str], + extra_explanations: dict[ControlNetModels, str], key: str = "controlnet_conditioning_scale", ): model = ControlNetModels[selected_controlnet_model] return gui.slider( label=f""" - {explanations[model]}. + {extra_explanations.get(model, model.explanation)}. 
""", key=key, min_value=CONTROLNET_CONDITIONING_SCALE_RANGE[0], @@ -215,7 +214,7 @@ def quality_setting(selected_models=None): return if any( [ - selected_model in [Text2ImgModels.dall_e_3.name] + selected_model in [TextToImageModels.dall_e_3.name] for selected_model in selected_models ] ): @@ -375,8 +374,8 @@ def sd_2_upscaling_setting(): def scheduler_setting(selected_model: str = None): if selected_model in [ - Text2ImgModels.dall_e.name, - Text2ImgModels.jack_qiao, + TextToImageModels.dall_e.name, + TextToImageModels.jack_qiao, ]: return enum_selector( @@ -395,8 +394,8 @@ def scheduler_setting(selected_model: str = None): def guidance_scale_setting(selected_model: str = None): if selected_model in [ - Text2ImgModels.dall_e.name, - Text2ImgModels.jack_qiao, + TextToImageModels.dall_e.name, + TextToImageModels.jack_qiao, ]: return gui.slider( @@ -435,8 +434,8 @@ def instruct_pix2pix_settings(): def prompt_strength_setting(selected_model: str = None): if selected_model in [ - Img2ImgModels.dall_e.name, - Img2ImgModels.instruct_pix2pix.name, + ImageToImageModels.dall_e.name, + ImageToImageModels.instruct_pix2pix.name, ]: return @@ -458,7 +457,7 @@ def prompt_strength_setting(selected_model: str = None): def negative_prompt_setting(selected_model: str = None): - if selected_model in [Text2ImgModels.dall_e.name]: + if selected_model in [TextToImageModels.dall_e.name]: return gui.text_area( diff --git a/daras_ai_v2/language_model.py b/daras_ai_v2/language_model.py index 27f9c5786..6ab0be2e9 100644 --- a/daras_ai_v2/language_model.py +++ b/daras_ai_v2/language_model.py @@ -25,6 +25,7 @@ from daras_ai.image_input import gs_url_to_uri, bytes_to_cv2_img, cv2_img_to_bytes from daras_ai_v2.asr import get_google_auth_session +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import raise_for_status, UserError from daras_ai_v2.gpu_server import call_celery_task from daras_ai_v2.text_splitter import ( @@ -72,7 +73,7 @@ class LLMSpec(typing.NamedTuple): supports_json: bool = False -class LargeLanguageModels(Enum): +class LargeLanguageModels(GooeyEnum): # https://platform.openai.com/docs/models/gpt-4o gpt_4_o = LLMSpec( label="GPT-4o (openai)", @@ -474,7 +475,9 @@ def get_entry_text(entry: ConversationEntry) -> str: ) -ResponseFormatType = typing.Literal["text", "json_object"] +class ResponseFormatType(str, GooeyEnum): + text = "text" + json_object = "json_object" def run_language_model( diff --git a/daras_ai_v2/lipsync_api.py b/daras_ai_v2/lipsync_api.py index 7aa81cc12..34f6c7f75 100644 --- a/daras_ai_v2/lipsync_api.py +++ b/daras_ai_v2/lipsync_api.py @@ -1,15 +1,15 @@ import typing -from enum import Enum from loguru import logger from pydantic import BaseModel, Field +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import UserError, GPUError from daras_ai_v2.gpu_server import call_celery_task_outfile_with_ret from daras_ai_v2.pydantic_validation import FieldHttpUrl -class LipsyncModel(Enum): +class LipsyncModels(GooeyEnum): Wav2Lip = "SD, Low-res (~480p), Fast (Rudrabha/Wav2Lip)" SadTalker = "HD, Hi-res (max 1080p), Slow (OpenTalker/SadTalker)" diff --git a/daras_ai_v2/lipsync_settings_widgets.py b/daras_ai_v2/lipsync_settings_widgets.py index 515be000b..b06032686 100644 --- a/daras_ai_v2/lipsync_settings_widgets.py +++ b/daras_ai_v2/lipsync_settings_widgets.py @@ -1,15 +1,15 @@ import gooey_gui as gui from daras_ai_v2.field_render import field_label_val -from daras_ai_v2.lipsync_api import LipsyncModel, SadTalkerSettings +from daras_ai_v2.lipsync_api 
import LipsyncModels, SadTalkerSettings def lipsync_settings(selected_model: str): match selected_model: - case LipsyncModel.Wav2Lip.name: + case LipsyncModels.Wav2Lip.name: wav2lip_settings() gui.session_state.pop("sadtalker_settings", None) - case LipsyncModel.SadTalker.name: + case LipsyncModels.SadTalker.name: settings = SadTalkerSettings.parse_obj( gui.session_state.setdefault( "sadtalker_settings", SadTalkerSettings().dict() diff --git a/daras_ai_v2/search_ref.py b/daras_ai_v2/search_ref.py index 0ed45386e..6ea91368e 100644 --- a/daras_ai_v2/search_ref.py +++ b/daras_ai_v2/search_ref.py @@ -1,10 +1,10 @@ import re import typing -from enum import Enum import jinja2 from typing_extensions import TypedDict +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import UserError from daras_ai_v2.scrollable_html_widget import scrollable_html @@ -16,7 +16,7 @@ class SearchReference(TypedDict): score: float -class CitationStyles(Enum): +class CitationStyles(GooeyEnum): number = "Numbers ( [1] [2] [3] ..)" title = "Source Title ( [Source 1] [Source 2] [Source 3] ..)" url = "Source URL ( [https://source1.com] [https://source2.com] [https://source3.com] ..)" diff --git a/daras_ai_v2/serp_search.py b/daras_ai_v2/serp_search.py index b95586461..7dea257de 100644 --- a/daras_ai_v2/serp_search.py +++ b/daras_ai_v2/serp_search.py @@ -20,7 +20,7 @@ def get_related_questions_from_serp_api( ) -> tuple[dict, list[str]]: data = call_serp_api( search_query, - search_type=SerpSearchType.SEARCH, + search_type=SerpSearchType.search, search_location=search_location, ) items = data.get("peopleAlsoAsk", []) or data.get("relatedSearches", []) @@ -66,10 +66,10 @@ def call_serp_api( search_location: SerpSearchLocation, ) -> dict: r = requests.post( - "https://google.serper.dev/" + search_type.value, + "https://google.serper.dev/" + search_type.api_value, json=dict( q=query, - gl=search_location.value, + gl=search_location.api_value, ), headers={"X-API-KEY": settings.SERPER_API_KEY}, ) diff --git a/daras_ai_v2/serp_search_locations.py b/daras_ai_v2/serp_search_locations.py index 2c26dc8a9..bcb39c053 100644 --- a/daras_ai_v2/serp_search_locations.py +++ b/daras_ai_v2/serp_search_locations.py @@ -1,8 +1,11 @@ -from django.db.models import TextChoices +import typing + from pydantic import BaseModel from pydantic import Field import gooey_gui as gui +from daras_ai_v2.custom_enum import GooeyEnum +from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.field_render import field_title_desc @@ -26,10 +29,10 @@ def serp_search_settings(): def serp_search_type_selectbox(key="serp_search_type"): - gui.selectbox( + enum_selector( + SerpSearchType, f"###### {field_title_desc(GoogleSearchMixin, key)}", - options=SerpSearchType, - format_func=lambda x: x.label, + use_selectbox=True, key=key, ) @@ -37,22 +40,31 @@ def serp_search_type_selectbox(key="serp_search_type"): def serp_search_location_selectbox(key="serp_search_location"): gui.selectbox( f"###### {field_title_desc(GoogleSearchMixin, key)}", - options=SerpSearchLocation, - format_func=lambda x: f"{x.label} ({x.value})", + options=[e.api_value for e in SerpSearchLocations], + format_func=lambda e: f"{SerpSearchLocations.from_api(e).label} ({e})", key=key, - value=SerpSearchLocation.UNITED_STATES, + value=SerpSearchLocations.UNITED_STATES.api_value, ) -class SerpSearchType(TextChoices): - SEARCH = "search", "🔎 Search" - IMAGES = "images", "📷 Images" - VIDEOS = "videos", "🎥 Videos" - PLACES = "places", "📍 Places" - NEWS = "news", "📰 
News" +class SerpSearchType(GooeyEnum): + search = "🔎 Search" + images = "📷 Images" + videos = "🎥 Videos" + places = "📍 Places" + news = "📰 News" + + @property + def label(self): + return self.value + + +class SerpSearchLocation(typing.NamedTuple): + api_value: str + label: str -class SerpSearchLocation(TextChoices): +class SerpSearchLocations(SerpSearchLocation, GooeyEnum): AFGHANISTAN = "af", "Afghanistan" ALBANIA = "al", "Albania" ALGERIA = "dz", "Algeria" @@ -304,7 +316,7 @@ class SerpSearchLocation(TextChoices): class GoogleSearchLocationMixin(BaseModel): - serp_search_location: SerpSearchLocation | None = Field( + serp_search_location: SerpSearchLocations.api_enum | None = Field( title="Web Search Location", ) scaleserp_locations: list[str] | None = Field( @@ -313,7 +325,7 @@ class GoogleSearchLocationMixin(BaseModel): class GoogleSearchMixin(GoogleSearchLocationMixin, BaseModel): - serp_search_type: SerpSearchType | None = Field( + serp_search_type: SerpSearchType.api_enum | None = Field( title="Web Search Type", ) scaleserp_search_field: str | None = Field( diff --git a/daras_ai_v2/stable_diffusion.py b/daras_ai_v2/stable_diffusion.py index f59b044c3..a6ee14422 100644 --- a/daras_ai_v2/stable_diffusion.py +++ b/daras_ai_v2/stable_diffusion.py @@ -1,6 +1,5 @@ import io import typing -from enum import Enum import requests from PIL import Image @@ -13,6 +12,7 @@ resize_img_fit, get_downscale_factor, ) +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import ( raise_for_status, UserError, @@ -27,179 +27,249 @@ SD_IMG_MAX_SIZE = (768, 768) -class InpaintingModels(Enum): - sd_2 = "Stable Diffusion v2.1 (stability.ai)" - runway_ml = "Stable Diffusion v1.5 (RunwayML)" - dall_e = "Dall-E (OpenAI)" +class InpaintingModel(typing.NamedTuple): + model_id: str | None + label: str - jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" + +class InpaintingModels(InpaintingModel, GooeyEnum): + sd_2 = InpaintingModel( + label="Stable Diffusion v2.1 (stability.ai)", + model_id="stabilityai/stable-diffusion-2-inpainting", + ) + runway_ml = InpaintingModel( + label="Stable Diffusion v1.5 (RunwayML)", + model_id="runwayml/stable-diffusion-inpainting", + ) + dall_e = InpaintingModel(label="Dall-E (OpenAI)", model_id="dall-e-2") + + jack_qiao = InpaintingModel( + label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None + ) @classmethod def _deprecated(cls): return {cls.jack_qiao} -inpaint_model_ids = { - InpaintingModels.sd_2: "stabilityai/stable-diffusion-2-inpainting", - InpaintingModels.runway_ml: "runwayml/stable-diffusion-inpainting", -} +class TextToImageModel(typing.NamedTuple): + model_id: str | None + label: str -class Text2ImgModels(Enum): +class TextToImageModels(TextToImageModel, GooeyEnum): # sd_1_4 = "SD v1.4 (RunwayML)" # Host this too? 
- dream_shaper = "DreamShaper (Lykon)" - dreamlike_2 = "Dreamlike Photoreal 2.0 (dreamlike.art)" - sd_2 = "Stable Diffusion v2.1 (stability.ai)" - sd_1_5 = "Stable Diffusion v1.5 (RunwayML)" + dream_shaper = TextToImageModel( + label="DreamShaper (Lykon)", model_id="Lykon/DreamShaper" + ) + dreamlike_2 = TextToImageModel( + label="Dreamlike Photoreal 2.0 (dreamlike.art)", + model_id="dreamlike-art/dreamlike-photoreal-2.0", + ) + sd_2 = TextToImageModel( + label="Stable Diffusion v2.1 (stability.ai)", + model_id="stabilityai/stable-diffusion-2-1", + ) + sd_1_5 = TextToImageModel( + label="Stable Diffusion v1.5 (RunwayML)", + model_id="runwayml/stable-diffusion-v1-5", + ) - dall_e = "DALL·E 2 (OpenAI)" - dall_e_3 = "DALL·E 3 (OpenAI)" + dall_e = TextToImageModel(label="DALL·E 2 (OpenAI)", model_id="dall-e-2") + dall_e_3 = TextToImageModel(label="DALL·E 3 (OpenAI)", model_id="dall-e-3") - openjourney_2 = "Open Journey v2 beta (PromptHero)" - openjourney = "Open Journey (PromptHero)" - analog_diffusion = "Analog Diffusion (wavymulder)" - protogen_5_3 = "Protogen v5.3 (darkstorm2150)" + openjourney_2 = TextToImageModel( + label="Open Journey v2 beta (PromptHero)", model_id="prompthero/openjourney-v2" + ) + openjourney = TextToImageModel( + label="Open Journey (PromptHero)", model_id="prompthero/openjourney" + ) + analog_diffusion = TextToImageModel( + label="Analog Diffusion (wavymulder)", model_id="wavymulder/Analog-Diffusion" + ) + protogen_5_3 = TextToImageModel( + label="Protogen v5.3 (darkstorm2150)", + model_id="darkstorm2150/Protogen_v5.3_Official_Release", + ) - jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" - rodent_diffusion_1_5 = "Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)" - deepfloyd_if = "DeepFloyd IF [Deprecated] (stability.ai)" + jack_qiao = TextToImageModel( + label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None + ) + rodent_diffusion_1_5 = TextToImageModel( + label="Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)", model_id=None + ) + deepfloyd_if = TextToImageModel( + label="DeepFloyd IF [Deprecated] (stability.ai)", model_id=None + ) @classmethod def _deprecated(cls): return {cls.jack_qiao, cls.deepfloyd_if, cls.rodent_diffusion_1_5} -text2img_model_ids = { - Text2ImgModels.sd_1_5: "runwayml/stable-diffusion-v1-5", - Text2ImgModels.sd_2: "stabilityai/stable-diffusion-2-1", - Text2ImgModels.dream_shaper: "Lykon/DreamShaper", - Text2ImgModels.analog_diffusion: "wavymulder/Analog-Diffusion", - Text2ImgModels.openjourney: "prompthero/openjourney", - Text2ImgModels.openjourney_2: "prompthero/openjourney-v2", - Text2ImgModels.dreamlike_2: "dreamlike-art/dreamlike-photoreal-2.0", - Text2ImgModels.protogen_5_3: "darkstorm2150/Protogen_v5.3_Official_Release", -} -dall_e_model_ids = { - Text2ImgModels.dall_e: "dall-e-2", - Text2ImgModels.dall_e_3: "dall-e-3", -} +class ImageToImageModel(typing.NamedTuple): + model_id: str | None + label: str -class Img2ImgModels(Enum): - dream_shaper = "DreamShaper (Lykon)" - dreamlike_2 = "Dreamlike Photoreal 2.0 (dreamlike.art)" - sd_2 = "Stable Diffusion v2.1 (stability.ai)" - sd_1_5 = "Stable Diffusion v1.5 (RunwayML)" +class ImageToImageModels(ImageToImageModel, GooeyEnum): + dream_shaper = ImageToImageModel( + label="DreamShaper (Lykon)", model_id="Lykon/DreamShaper" + ) + dreamlike_2 = ImageToImageModel( + label="Dreamlike Photoreal 2.0 (dreamlike.art)", + model_id="dreamlike-art/dreamlike-photoreal-2.0", + ) + sd_2 = ImageToImageModel( + label="Stable Diffusion v2.1 (stability.ai)", + 
model_id="stabilityai/stable-diffusion-2-1", + ) + sd_1_5 = ImageToImageModel( + label="Stable Diffusion v1.5 (RunwayML)", + model_id="runwayml/stable-diffusion-v1-5", + ) - dall_e = "Dall-E (OpenAI)" + dall_e = ImageToImageModel(label="Dall-E (OpenAI)", model_id=None) - instruct_pix2pix = "✨ InstructPix2Pix (Tim Brooks)" - openjourney_2 = "Open Journey v2 beta (PromptHero) 🐢" - openjourney = "Open Journey (PromptHero) 🐢" - analog_diffusion = "Analog Diffusion (wavymulder) 🐢" - protogen_5_3 = "Protogen v5.3 (darkstorm2150) 🐢" + instruct_pix2pix = ImageToImageModel( + label="✨ InstructPix2Pix (Tim Brooks)", model_id=None + ) + openjourney_2 = ImageToImageModel( + label="Open Journey v2 beta (PromptHero) 🐢", + model_id="prompthero/openjourney-v2", + ) + openjourney = ImageToImageModel( + label="Open Journey (PromptHero) 🐢", model_id="prompthero/openjourney" + ) + analog_diffusion = ImageToImageModel( + label="Analog Diffusion (wavymulder) 🐢", model_id="wavymulder/Analog-Diffusion" + ) + protogen_5_3 = ImageToImageModel( + label="Protogen v5.3 (darkstorm2150) 🐢", + model_id="darkstorm2150/Protogen_v5.3_Official_Release", + ) - jack_qiao = "Stable Diffusion v1.4 [Deprecated] (Jack Qiao)" - rodent_diffusion_1_5 = "Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)" + jack_qiao = ImageToImageModel( + label="Stable Diffusion v1.4 [Deprecated] (Jack Qiao)", model_id=None + ) + rodent_diffusion_1_5 = ImageToImageModel( + label="Rodent Diffusion 1.5 [Deprecated] (NerdyRodent)", model_id=None + ) @classmethod def _deprecated(cls): return {cls.jack_qiao, cls.rodent_diffusion_1_5} -img2img_model_ids = { - Img2ImgModels.sd_2: "stabilityai/stable-diffusion-2-1", - Img2ImgModels.sd_1_5: "runwayml/stable-diffusion-v1-5", - Img2ImgModels.dream_shaper: "Lykon/DreamShaper", - Img2ImgModels.openjourney: "prompthero/openjourney", - Img2ImgModels.openjourney_2: "prompthero/openjourney-v2", - Img2ImgModels.analog_diffusion: "wavymulder/Analog-Diffusion", - Img2ImgModels.protogen_5_3: "darkstorm2150/Protogen_v5.3_Official_Release", - Img2ImgModels.dreamlike_2: "dreamlike-art/dreamlike-photoreal-2.0", -} - - -class ControlNetModels(Enum): - sd_controlnet_canny = "Canny" - sd_controlnet_depth = "Depth" - sd_controlnet_hed = "HED Boundary" - sd_controlnet_mlsd = "M-LSD Straight Line" - sd_controlnet_normal = "Normal Map" - sd_controlnet_openpose = "Human Pose" - sd_controlnet_scribble = "Scribble" - sd_controlnet_seg = "Image Segmentation" - sd_controlnet_tile = "Tiling" - sd_controlnet_brightness = "Brightness" - control_v1p_sd15_qrcode_monster_v2 = "QR Monster V2" - - -controlnet_model_explanations = { - ControlNetModels.sd_controlnet_canny: "Canny edge detection", - ControlNetModels.sd_controlnet_depth: "Depth estimation", - ControlNetModels.sd_controlnet_hed: "HED edge detection", - ControlNetModels.sd_controlnet_mlsd: "M-LSD straight line detection", - ControlNetModels.sd_controlnet_normal: "Normal map estimation", - ControlNetModels.sd_controlnet_openpose: "Human pose estimation", - ControlNetModels.sd_controlnet_scribble: "Scribble", - ControlNetModels.sd_controlnet_seg: "Image segmentation", - ControlNetModels.sd_controlnet_tile: "Tiling: to preserve small details", - ControlNetModels.sd_controlnet_brightness: "Brightness: to increase contrast naturally", - ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this purpose", -} - -controlnet_model_ids = { - ControlNetModels.sd_controlnet_canny: 
"lllyasviel/sd-controlnet-canny", - ControlNetModels.sd_controlnet_depth: "lllyasviel/sd-controlnet-depth", - ControlNetModels.sd_controlnet_hed: "lllyasviel/sd-controlnet-hed", - ControlNetModels.sd_controlnet_mlsd: "lllyasviel/sd-controlnet-mlsd", - ControlNetModels.sd_controlnet_normal: "lllyasviel/sd-controlnet-normal", - ControlNetModels.sd_controlnet_openpose: "lllyasviel/sd-controlnet-openpose", - ControlNetModels.sd_controlnet_scribble: "lllyasviel/sd-controlnet-scribble", - ControlNetModels.sd_controlnet_seg: "lllyasviel/sd-controlnet-seg", - ControlNetModels.sd_controlnet_tile: "lllyasviel/control_v11f1e_sd15_tile", - ControlNetModels.sd_controlnet_brightness: "ioclab/control_v1p_sd15_brightness", - ControlNetModels.control_v1p_sd15_qrcode_monster_v2: "monster-labs/control_v1p_sd15_qrcode_monster/v2", -} - - -class Schedulers(models.TextChoices): - singlestep_dpm_solver = ( - "DPM", - "DPMSolverSinglestepScheduler", - ) - multistep_dpm_solver = "DPM Multistep", "DPMSolverMultistepScheduler" - dpm_sde = ( - "DPM SDE", - "DPMSolverSDEScheduler", - ) - dpm_discrete = ( - "DPM Discrete", - "KDPM2DiscreteScheduler", - ) - dpm_discrete_ancestral = ( - "DPM Anscetral", - "KDPM2AncestralDiscreteScheduler", - ) - unipc = "UniPC", "UniPCMultistepScheduler" - lms_discrete = ( - "LMS", - "LMSDiscreteScheduler", - ) - heun = ( - "Heun", - "HeunDiscreteScheduler", - ) - euler = "Euler", "EulerDiscreteScheduler" - euler_ancestral = ( - "Euler ancestral", - "EulerAncestralDiscreteScheduler", - ) - pndm = "PNDM", "PNDMScheduler" - ddpm = "DDPM", "DDPMScheduler" - ddim = "DDIM", "DDIMScheduler" - deis = ( - "DEIS", - "DEISMultistepScheduler", +class ControlNetModel(typing.NamedTuple): + label: str + model_id: str + explanation: str + + +class ControlNetModels(ControlNetModel, GooeyEnum): + sd_controlnet_canny = ControlNetModel( + label="Canny", + explanation="Canny edge detection", + model_id="lllyasviel/sd-controlnet-canny", + ) + sd_controlnet_depth = ControlNetModel( + label="Depth", + explanation="Depth estimation", + model_id="lllyasviel/sd-controlnet-depth", + ) + sd_controlnet_hed = ControlNetModel( + label="HED Boundary", + explanation="HED edge detection", + model_id="lllyasviel/sd-controlnet-hed", + ) + sd_controlnet_mlsd = ControlNetModel( + label="M-LSD Straight Line", + explanation="M-LSD straight line detection", + model_id="lllyasviel/sd-controlnet-mlsd", + ) + sd_controlnet_normal = ControlNetModel( + label="Normal Map", + explanation="Normal map estimation", + model_id="lllyasviel/sd-controlnet-normal", + ) + sd_controlnet_openpose = ControlNetModel( + label="Human Pose", + explanation="Human pose estimation", + model_id="lllyasviel/sd-controlnet-openpose", + ) + sd_controlnet_scribble = ControlNetModel( + label="Scribble", + explanation="Scribble", + model_id="lllyasviel/sd-controlnet-scribble", + ) + sd_controlnet_seg = ControlNetModel( + label="Image Segmentation", + explanation="Image segmentation", + model_id="lllyasviel/sd-controlnet-seg", + ) + sd_controlnet_tile = ControlNetModel( + label="Tiling", + explanation="Tiling: to preserve small details", + model_id="lllyasviel/control_v11f1e_sd15_tile", + ) + sd_controlnet_brightness = ControlNetModel( + label="Brightness", + explanation="Brightness: to increase contrast naturally", + model_id="ioclab/control_v1p_sd15_brightness", + ) + control_v1p_sd15_qrcode_monster_v2 = ControlNetModel( + label="QR Monster V2", + explanation="QR Monster: make beautiful QR codes that still scan with a controlnet specifically trained for this 
purpose", + model_id="monster-labs/control_v1p_sd15_qrcode_monster/v2", + ) + + +class Scheduler(typing.NamedTuple): + label: str + model_id: str + + +class Schedulers(Scheduler, GooeyEnum): + singlestep_dpm_solver = Scheduler( + label="DPM", + model_id="DPMSolverSinglestepScheduler", + ) + multistep_dpm_solver = Scheduler( + label="DPM Multistep", model_id="DPMSolverMultistepScheduler" + ) + dpm_sde = Scheduler( + label="DPM SDE", + model_id="DPMSolverSDEScheduler", + ) + dpm_discrete = Scheduler( + label="DPM Discrete", + model_id="KDPM2DiscreteScheduler", + ) + dpm_discrete_ancestral = Scheduler( + label="DPM Anscetral", + model_id="KDPM2AncestralDiscreteScheduler", + ) + unipc = Scheduler(label="UniPC", model_id="UniPCMultistepScheduler") + lms_discrete = Scheduler( + label="LMS", + model_id="LMSDiscreteScheduler", + ) + heun = Scheduler( + label="Heun", + model_id="HeunDiscreteScheduler", + ) + euler = Scheduler("Euler", "EulerDiscreteScheduler") + euler_ancestral = Scheduler( + label="Euler ancestral", + model_id="EulerAncestralDiscreteScheduler", + ) + pndm = Scheduler(label="PNDM", model_id="PNDMScheduler") + ddpm = Scheduler(label="DDPM", model_id="DDPMScheduler") + ddim = Scheduler(label="DDIM", model_id="DDIMScheduler") + deis = Scheduler( + label="DEIS", + model_id="DEISMultistepScheduler", ) @@ -282,18 +352,18 @@ def text2img( dall_e_3_quality: str | None = None, dall_e_3_style: str | None = None, ): - if selected_model != Text2ImgModels.dall_e_3.name: + if selected_model != TextToImageModels.dall_e_3.name: _resolution_check(width, height, max_size=(1024, 1024)) match selected_model: - case Text2ImgModels.dall_e_3.name: + case TextToImageModels.dall_e_3.name: from openai import OpenAI client = OpenAI() width, height = _get_dall_e_3_img_size(width, height) with capture_openai_content_policy_violation(): response = client.images.generate( - model=dall_e_model_ids[Text2ImgModels[selected_model]], + model=TextToImageModels[selected_model].model_id, n=1, # num_outputs, not supported yet prompt=prompt, response_format="b64_json", @@ -302,7 +372,7 @@ def text2img( size=f"{width}x{height}", ) out_imgs = [b64_img_decode(part.b64_json) for part in response.data] - case Text2ImgModels.dall_e.name: + case TextToImageModels.dall_e.name: from openai import OpenAI edge = _get_dall_e_img_size(width, height) @@ -320,8 +390,8 @@ def text2img( return call_sd_multi( "diffusion.text2img", pipeline={ - "model_id": text2img_model_ids[Text2ImgModels[selected_model]], - "scheduler": Schedulers[scheduler].label if scheduler else None, + "model_id": TextToImageModels[selected_model].model_id, + "scheduler": Schedulers[scheduler].model_id if scheduler else None, "disable_safety_checker": True, "seed": seed, }, @@ -382,7 +452,7 @@ def img2img( _resolution_check(width, height) match selected_model: - case Img2ImgModels.dall_e.name: + case ImageToImageModels.dall_e.name: from openai import OpenAI edge = _get_dall_e_img_size(width, height) @@ -413,7 +483,7 @@ def img2img( return call_sd_multi( "diffusion.img2img", pipeline={ - "model_id": img2img_model_ids[Img2ImgModels[selected_model]], + "model_id": ImageToImageModels[selected_model].model_id, # "scheduler": "UniPCMultistepScheduler", "disable_safety_checker": True, "seed": seed, @@ -456,15 +526,16 @@ def controlnet( return call_sd_multi( "diffusion.controlnet", pipeline={ - "model_id": text2img_model_ids[Text2ImgModels[selected_model]], + "model_id": TextToImageModels[selected_model].model_id, "seed": seed, "scheduler": ( - Schedulers[scheduler].label if 
scheduler else "UniPCMultistepScheduler" + Schedulers[scheduler].model_id + if scheduler + else Schedulers.unipc.model_id ), "disable_safety_checker": True, "controlnet_model_id": [ - controlnet_model_ids[ControlNetModels[model]] - for model in selected_controlnet_model + ControlNetModels[model].model_id for model in selected_controlnet_model ], }, inputs={ @@ -482,13 +553,13 @@ def controlnet( def add_prompt_prefix(prompt: str, selected_model: str) -> str: match selected_model: - case Text2ImgModels.openjourney.name: + case TextToImageModels.openjourney.name: prompt = "mdjrny-v4 style " + prompt - case Text2ImgModels.analog_diffusion.name: + case TextToImageModels.analog_diffusion.name: prompt = "analog style " + prompt - case Text2ImgModels.protogen_5_3.name: + case TextToImageModels.protogen_5_3.name: prompt = "modelshoot style " + prompt - case Text2ImgModels.dreamlike_2.name: + case TextToImageModels.dreamlike_2.name: prompt = "photo, " + prompt return prompt @@ -535,9 +606,9 @@ def inpainting( out_imgs_urls = call_sd_multi( "diffusion.inpaint", pipeline={ - "model_id": inpaint_model_ids[InpaintingModels[selected_model]], + "model_id": InpaintingModels[selected_model].model_id, "seed": seed, - # "scheduler": Schedulers[scheduler].label + # "scheduler": Schedulers[scheduler].model_id # if scheduler # else "UniPCMultistepScheduler", "disable_safety_checker": True, diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index 5214d5cc5..65b5b56ea 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -46,7 +46,7 @@ class OpenAI_TTS_Voices(GooeyEnum): shimmer = "shimmer" -class TextToSpeechProviders(Enum): +class TextToSpeechProviders(GooeyEnum): GOOGLE_TTS = "Google Text-to-Speech" ELEVEN_LABS = "Eleven Labs" UBERDUCK = "Uberduck.ai" diff --git a/daras_ai_v2/upscaler_models.py b/daras_ai_v2/upscaler_models.py index 11d8ff225..f5d46a40a 100644 --- a/daras_ai_v2/upscaler_models.py +++ b/daras_ai_v2/upscaler_models.py @@ -1,11 +1,11 @@ import typing -from enum import Enum from pathlib import Path import replicate import requests from daras_ai.image_input import upload_file_from_bytes +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.exceptions import UserError from daras_ai_v2.gpu_server import call_celery_task_outfile from daras_ai_v2.pydantic_validation import FieldHttpUrl @@ -19,7 +19,7 @@ class UpscalerModel(typing.NamedTuple): is_bg_model: bool = False -class UpscalerModels(UpscalerModel, Enum): +class UpscalerModels(UpscalerModel, GooeyEnum): gfpgan_1_4 = UpscalerModel( model_id="GFPGANv1.4", label="GFPGAN v1.4 (Tencent ARC)", diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 36de65f29..ba0c4ce59 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -71,8 +71,13 @@ class DocSearchRequest(BaseModel): + class Config: + use_enum_values = True + search_query: str - keyword_query: str | list[str] | None + keyword_query: str | list[str] | None = Field( + **{"x-fern-type-name": "KeywordQuery"} + ) documents: list[str] | None @@ -82,7 +87,7 @@ class DocSearchRequest(BaseModel): doc_extract_url: str | None - embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None + embedding_model: EmbeddingModels.api_enum | None dense_weight: float | None = Field( ge=0.0, le=1.0, diff --git a/embeddings/models.py b/embeddings/models.py index f30b1a0bf..f5f350bbc 100644 --- a/embeddings/models.py +++ 
b/embeddings/models.py
@@ -28,7 +28,7 @@ class EmbeddedFile(models.Model):
     selected_asr_model = models.CharField(max_length=100, blank=True)
     embedding_model = models.CharField(
         max_length=100,
-        choices=[(model.name, model.label) for model in EmbeddingModels],
+        choices=EmbeddingModels.db_choices(),
         default=EmbeddingModels.openai_3_large.name,
     )
 
diff --git a/functions/models.py b/functions/models.py
index be1be7a5e..711805d64 100644
--- a/functions/models.py
+++ b/functions/models.py
@@ -22,7 +22,7 @@ class RecipeFunction(BaseModel):
         title="URL",
         description="The URL of the [function](https://gooey.ai/functions) to call.",
     )
-    trigger: FunctionTrigger.api_choices = Field(
+    trigger: FunctionTrigger.api_enum = Field(
         title="Trigger",
         description="When to run this function. `pre` runs before the recipe, `post` runs after the recipe.",
     )
@@ -30,7 +30,7 @@ class RecipeFunction(BaseModel):
 
 class CalledFunctionResponse(BaseModel):
     url: str
-    trigger: FunctionTrigger.api_choices
+    trigger: FunctionTrigger.api_enum
     return_value: typing.Any
 
     @classmethod
diff --git a/recipes/BulkEval.py b/recipes/BulkEval.py
index 7f92a53e1..62aabe8c3 100644
--- a/recipes/BulkEval.py
+++ b/recipes/BulkEval.py
@@ -139,6 +139,7 @@ class BulkEvalPage(BasePage):
     title = "Evaluator"
     workflow = Workflow.BULK_EVAL
     slug_versions = ["bulk-eval", "eval"]
+    sdk_method_name = "eval"
 
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aad314f0-9a97-11ee-8318-02420a0001c7/W.I.9.png.png"
 
@@ -185,9 +186,7 @@ class RequestModelBase(BasePage.RequestModel):
             """,
         )
 
-        selected_model: (
-            typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
-        )
+        selected_model: LargeLanguageModels.api_enum | None
 
     class RequestModel(LanguageModelSettings, RequestModelBase):
         pass
diff --git a/recipes/BulkRunner.py b/recipes/BulkRunner.py
index 7ad9e67e9..7ab4e4358 100644
--- a/recipes/BulkRunner.py
+++ b/recipes/BulkRunner.py
@@ -43,6 +43,8 @@ class BulkRunnerPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/87f35df4-88d7-11ee-aac9-02420a00016b/Bulk%20Runner.png.png"
     workflow = Workflow.BULK_RUNNER
     slug_versions = ["bulk-runner", "bulk"]
+    sdk_method_name = "bulkRun"
+
     price = 1
 
     class RequestModel(BasePage.RequestModel):
diff --git a/recipes/ChyronPlant.py b/recipes/ChyronPlant.py
index fef4714b9..0777b93c7 100644
--- a/recipes/ChyronPlant.py
+++ b/recipes/ChyronPlant.py
@@ -1,4 +1,5 @@
+import typing
 import gooey_gui as gui
 from pydantic import BaseModel
 
 from bots.models import Workflow
@@ -12,6 +13,7 @@ class ChyronPlantPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png"
     workflow = Workflow.CHYRON_PLANT
     slug_versions = ["ChyronPlant"]
+    sdk_method_name = ""
 
     class RequestModel(BasePage.RequestModel):
         midi_notes: str
@@ -30,6 +32,10 @@ class ResponseModel(BaseModel):
         midi_translation: str
         chyron_output: str
 
+    @classmethod
+    def get_openapi_extra(cls) -> dict[str, typing.Any]:
+        return {"x-fern-ignore": True}
+
     def render_form_v2(self):
         gui.text_input(
             """
diff --git a/recipes/CompareLLM.py b/recipes/CompareLLM.py
index 513421146..55d317176 100644
--- a/recipes/CompareLLM.py
+++ b/recipes/CompareLLM.py
@@ -30,6 +30,7 @@ class CompareLLMPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ae42015e-88d7-11ee-aac9-02420a00016b/Compare%20LLMs.png.png"
workflow = Workflow.COMPARE_LLM slug_versions = ["CompareLLM", "llm", "compare-large-language-models"] + sdk_method_name = "llm" functions_in_settings = False @@ -43,17 +44,16 @@ class CompareLLMPage(BasePage): class RequestModelBase(BasePage.RequestModel): input_prompt: str | None - selected_models: ( - list[typing.Literal[tuple(e.name for e in LargeLanguageModels)]] | None - ) + selected_models: list[LargeLanguageModels.api_enum] | None class RequestModel(LanguageModelSettings, RequestModelBase): pass class ResponseModel(BaseModel): - output_text: dict[ - typing.Literal[tuple(e.name for e in LargeLanguageModels)], list[str] - ] + class Config: + use_enum_values = True + + output_text: dict[LargeLanguageModels.api_enum, list[str]] def preview_image(self, state: dict) -> str | None: return DEFAULT_COMPARE_LM_META_IMG diff --git a/recipes/CompareText2Img.py b/recipes/CompareText2Img.py index f41a46170..ee0e985fb 100644 --- a/recipes/CompareText2Img.py +++ b/recipes/CompareText2Img.py @@ -20,7 +20,7 @@ from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.safety_checker import safety_checker from daras_ai_v2.stable_diffusion import ( - Text2ImgModels, + TextToImageModels, text2img, instruct_pix2pix, sd_upscale, @@ -39,6 +39,7 @@ class CompareText2ImgPage(BasePage): "text2img", "compare-ai-image-generators", ] + sdk_method_name = "textToImage" sane_defaults = { "guidance_scale": 7.5, @@ -65,19 +66,14 @@ class RequestModel(BasePage.RequestModel): seed: int | None sd_2_upscaling: bool | None - selected_models: ( - list[typing.Literal[tuple(e.name for e in Text2ImgModels)]] | None - ) - scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None + selected_models: list[TextToImageModels.api_enum] | None + scheduler: Schedulers.api_enum | None edit_instruction: str | None image_guidance_scale: float | None class ResponseModel(BaseModel): - output_images: dict[ - typing.Literal[tuple(e.name for e in Text2ImgModels)], - list[FieldHttpUrl], - ] + output_images: dict[TextToImageModels.api_enum, list[FieldHttpUrl]] @classmethod def get_example_preferred_fields(cls, state: dict) -> list[str]: @@ -124,7 +120,7 @@ def render_form_v2(self): """ ) enum_multiselect( - Text2ImgModels, + TextToImageModels, key="selected_models", ) @@ -192,7 +188,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]: state["output_images"] = output_images = {} for selected_model in request.selected_models: - yield f"Running {Text2ImgModels[selected_model].value}..." + yield f"Running {TextToImageModels[selected_model].label}..." 
output_images[selected_model] = text2img( selected_model=selected_model, @@ -253,7 +249,7 @@ def _render_outputs(self, state): output_images: dict = state.get("output_images", {}).get(key, []) for img in output_images: gui.image( - img, caption=Text2ImgModels[key].value, show_download_button=True + img, caption=TextToImageModels[key].label, show_download_button=True ) def preview_description(self, state: dict) -> str: @@ -264,9 +260,9 @@ def get_raw_price(self, state: dict) -> int: total = 0 for name in selected_models: match name: - case Text2ImgModels.deepfloyd_if.name: + case TextToImageModels.deepfloyd_if.name: total += 5 - case Text2ImgModels.dall_e.name | Text2ImgModels.dall_e_3.name: + case TextToImageModels.dall_e.name | TextToImageModels.dall_e_3.name: total += 15 case _: total += 2 diff --git a/recipes/CompareUpscaler.py b/recipes/CompareUpscaler.py index 4684ab309..e7032095c 100644 --- a/recipes/CompareUpscaler.py +++ b/recipes/CompareUpscaler.py @@ -20,6 +20,7 @@ class CompareUpscalerPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/64393e0c-88db-11ee-b428-02420a000168/AI%20Image%20Upscaler.png.png" workflow = Workflow.COMPARE_UPSCALER slug_versions = ["compare-ai-upscalers"] + sdk_method_name = "upscale" class RequestModel(BasePage.RequestModel): input_image: FieldHttpUrl | None = Field(None, description="Input Image") @@ -29,21 +30,22 @@ class RequestModel(BasePage.RequestModel): description="The final upsampling scale of the image", ge=1, le=4 ) - selected_models: ( - list[typing.Literal[tuple(e.name for e in UpscalerModels)]] | None - ) + selected_models: list[UpscalerModels.api_enum] | None selected_bg_model: ( typing.Literal[tuple(e.name for e in UpscalerModels if e.is_bg_model)] | None + ) = Field( + title="Selected Background Model", + **{"x-fern-type-name": "BackgroundUpscalerModels"}, ) class ResponseModel(BaseModel): - output_images: dict[ - typing.Literal[tuple(e.name for e in UpscalerModels)], FieldHttpUrl - ] = Field({}, description="Output Images") - output_videos: dict[ - typing.Literal[tuple(e.name for e in UpscalerModels)], FieldHttpUrl - ] = Field({}, description="Output Videos") + output_images: dict[UpscalerModels.api_enum, FieldHttpUrl] = Field( + default_factory=dict, description="Output Images" + ) + output_videos: dict[UpscalerModels.api_enum, FieldHttpUrl] = Field( + default_factory=dict, description="Output Videos" + ) def validate_form_v2(self): assert gui.session_state.get( @@ -69,6 +71,7 @@ def run_v2( for selected_model in request.selected_models: model = UpscalerModels[selected_model] yield f"Running {model.label}..." 
+ print(f"{request.input_image=}, {request.input_video=}") if request.input_image: response.output_images[selected_model] = run_upscaler_model( selected_model=model, diff --git a/recipes/DeforumSD.py b/recipes/DeforumSD.py index ef4b186b5..eba951bf4 100644 --- a/recipes/DeforumSD.py +++ b/recipes/DeforumSD.py @@ -1,8 +1,8 @@ import typing import uuid +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.pydantic_validation import FieldHttpUrl -from django.db.models import TextChoices from pydantic import BaseModel from typing_extensions import TypedDict @@ -18,17 +18,28 @@ DEFAULT_DEFORUMSD_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/7dc25196-93fe-11ee-9e3a-02420a0001ce/AI%20Animation%20generator.jpg.png" -class AnimationModels(TextChoices): - protogen_2_2 = ("Protogen_V2.2.ckpt", "Protogen V2.2 (darkstorm2150)") - epicdream = ("epicdream.safetensors", "epiCDream (epinikion)") +class AnimationModel(typing.NamedTuple): + model_id: str + label: str -class _AnimationPrompt(TypedDict): +class AnimationModels(AnimationModel, GooeyEnum): + protogen_2_2 = AnimationModel( + model_id="Protogen_V2.2.ckpt", + label="Protogen V2.2 (darkstorm2150)", + ) + epicdream = AnimationModel( + model_id="epicdream.safetensors", + label="epiCDream (epinikion)", + ) + + +class AnimationPrompt(TypedDict): frame: str prompt: str -AnimationPrompts = list[_AnimationPrompt] +AnimationPrompts = list[AnimationPrompt] CREDITS_PER_FRAME = 1.5 MODEL_ESTIMATED_TIME_PER_FRAME = 2.4 # seconds @@ -166,6 +177,7 @@ class DeforumSDPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/media/users/kxmNIYAOJbfOURxHBKNCWeUSKiP2/dd88c110-88d6-11ee-9b4f-2b58bd50e819/animation.gif" workflow = Workflow.DEFORUM_SD slug_versions = ["DeforumSD", "animation-generator"] + sdk_method_name = "animate" sane_defaults = dict( zoom="0: (1.004)", @@ -185,7 +197,7 @@ class RequestModel(BasePage.RequestModel): animation_prompts: AnimationPrompts max_frames: int | None - selected_model: typing.Literal[tuple(e.name for e in AnimationModels)] | None + selected_model: AnimationModels.api_enum | None animation_mode: str | None zoom: str | None @@ -459,7 +471,7 @@ def run(self, state: dict): state["output_video"] = call_celery_task_outfile( "deforum", pipeline=dict( - model_id=AnimationModels[request.selected_model].value, + model_id=AnimationModels[request.selected_model].model_id, seed=request.seed, ), inputs=dict( diff --git a/recipes/DocExtract.py b/recipes/DocExtract.py index 323e3eab9..9a41f4f19 100644 --- a/recipes/DocExtract.py +++ b/recipes/DocExtract.py @@ -86,6 +86,7 @@ class DocExtractPage(BasePage): "youtube-bot", "doc-extract", ] + sdk_method_name = "synthesizeData" price = 500 class RequestModelBase(BasePage.RequestModel): @@ -93,7 +94,7 @@ class RequestModelBase(BasePage.RequestModel): sheet_url: FieldHttpUrl | None - selected_asr_model: typing.Literal[tuple(e.name for e in AsrModels)] | None + selected_asr_model: AsrModels.api_enum | None # language: str | None google_translate_target: str | None glossary_document: FieldHttpUrl | None = Field( @@ -104,9 +105,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/DocSearch.py b/recipes/DocSearch.py index d0c49f3c7..140ea796e 100644 --- a/recipes/DocSearch.py 
+++ b/recipes/DocSearch.py @@ -2,7 +2,7 @@ import typing from furl import furl -from pydantic import BaseModel +from pydantic import BaseModel, Field import gooey_gui as gui from bots.models import Workflow @@ -52,6 +52,7 @@ class DocSearchPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/cbbb4dc6-88d7-11ee-bf6c-02420a000166/Search%20your%20docs%20with%20gpt.png.png" workflow = Workflow.DOC_SEARCH slug_versions = ["doc-search"] + sdk_method_name = "rag" sane_defaults = { "sampling_temperature": 0.1, @@ -71,11 +72,9 @@ class RequestModelBase(DocSearchRequest, BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None - citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None + citation_style: CitationStyles.api_enum | None class RequestModel(LanguageModelSettings, RequestModelBase): pass diff --git a/recipes/DocSummary.py b/recipes/DocSummary.py index 18412e197..1711de053 100644 --- a/recipes/DocSummary.py +++ b/recipes/DocSummary.py @@ -1,8 +1,8 @@ import typing -from enum import Enum +from daras_ai_v2.custom_enum import GooeyEnum from daras_ai_v2.pydantic_validation import FieldHttpUrl -from pydantic import BaseModel, Field +from pydantic import BaseModel import gooey_gui as gui from bots.models import Workflow @@ -16,7 +16,6 @@ LargeLanguageModels, run_language_model, calc_gpt_tokens, - ResponseFormatType, ) from daras_ai_v2.language_model_settings_widgets import ( language_model_settings, @@ -38,7 +37,7 @@ DEFAULT_DOC_SUMMARY_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f35796d2-93fe-11ee-b86c-02420a0001c7/Summarize%20with%20GPT.jpg.png" -class CombineDocumentsChains(Enum): +class CombineDocumentsChains(GooeyEnum): map_reduce = "Map Reduce" # refine = "Refine" # stuff = "Stuffing (Only works for small documents)" @@ -49,6 +48,7 @@ class DocSummaryPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1f858a7a-88d8-11ee-a658-02420a000163/Summarize%20your%20docs%20with%20gpt.png.png" workflow = Workflow.DOC_SUMMARY slug_versions = ["doc-summary"] + sdk_method_name = "docSummary" price = 225 @@ -68,13 +68,11 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None merge_instructions: str | None - selected_model: ( - typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None - chain_type: typing.Literal[tuple(e.name for e in CombineDocumentsChains)] | None + chain_type: CombineDocumentsChains.api_enum | None - selected_asr_model: typing.Literal[tuple(e.name for e in AsrModels)] | None + selected_asr_model: AsrModels.api_enum | None google_translate_target: str | None class RequestModel(LanguageModelSettings, RequestModelBase): diff --git a/recipes/EmailFaceInpainting.py b/recipes/EmailFaceInpainting.py index 3ebc161a9..0fcf39835 100644 --- a/recipes/EmailFaceInpainting.py +++ b/recipes/EmailFaceInpainting.py @@ -27,6 +27,7 @@ class EmailFaceInpaintingPage(FaceInpaintingPage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ec0df5aa-9521-11ee-93d3-02420a0001e5/Email%20Profile%20Lookup.png.png" workflow = Workflow.EMAIL_FACE_INPAINTING slug_versions = ["EmailFaceInpainting", "ai-image-from-email-lookup"] + sdk_method_name = "imageFromEmail" 
sane_defaults = { "num_outputs": 1, @@ -49,7 +50,7 @@ class RequestModel(BasePage.RequestModel): face_pos_x: float | None face_pos_y: float | None - selected_model: typing.Literal[tuple(e.name for e in InpaintingModels)] | None + selected_model: InpaintingModels.api_enum | None negative_prompt: str | None diff --git a/recipes/FaceInpainting.py b/recipes/FaceInpainting.py index 8770740fb..a28041d37 100644 --- a/recipes/FaceInpainting.py +++ b/recipes/FaceInpainting.py @@ -33,6 +33,7 @@ class FaceInpaintingPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/10c2ce06-88da-11ee-b428-02420a000168/ai%20image%20with%20a%20face.png.png" workflow = Workflow.FACE_INPAINTING slug_versions = ["FaceInpainting", "face-in-ai-generated-photo"] + sdk_method_name = "portrait" sane_defaults = { "num_outputs": 1, @@ -52,7 +53,7 @@ class RequestModel(BasePage.RequestModel): face_pos_x: float | None face_pos_y: float | None - selected_model: typing.Literal[tuple(e.name for e in InpaintingModels)] | None + selected_model: InpaintingModels.api_enum | None negative_prompt: str | None diff --git a/recipes/Functions.py b/recipes/Functions.py index 81b99c946..434effd27 100644 --- a/recipes/Functions.py +++ b/recipes/Functions.py @@ -22,6 +22,7 @@ class FunctionsPage(BasePage): title = "Functions" workflow = Workflow.FUNCTIONS slug_versions = ["functions", "tools", "function", "fn", "functions"] + sdk_method_name = "functions" show_settings = False price = 1 diff --git a/recipes/GoogleGPT.py b/recipes/GoogleGPT.py index 611483287..93d7a0ed2 100644 --- a/recipes/GoogleGPT.py +++ b/recipes/GoogleGPT.py @@ -1,7 +1,7 @@ import typing from furl import furl -from pydantic import BaseModel, Field +from pydantic import BaseModel import gooey_gui as gui from bots.models import Workflow @@ -11,11 +11,7 @@ doc_search_advanced_settings, ) from daras_ai_v2.embedding_model import EmbeddingModels -from daras_ai_v2.language_model import ( - run_language_model, - LargeLanguageModels, - ResponseFormatType, -) +from daras_ai_v2.language_model import run_language_model, LargeLanguageModels from daras_ai_v2.language_model_settings_widgets import ( language_model_settings, language_model_selector, @@ -31,9 +27,9 @@ from daras_ai_v2.serp_search import get_links_from_serp_api from daras_ai_v2.serp_search_locations import ( GoogleSearchMixin, - serp_search_settings, - SerpSearchLocation, + SerpSearchLocations, SerpSearchType, + serp_search_settings, ) from daras_ai_v2.vector_search import render_sources_widget from recipes.DocSearch import ( @@ -51,6 +47,7 @@ class GoogleGPTPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/28649544-9406-11ee-bba3-02420a0001cc/Websearch%20GPT%20option%202.png.png" workflow = Workflow.GOOGLE_GPT slug_versions = ["google-gpt"] + sdk_method_name = "webSearchLLM" price = 175 @@ -59,8 +56,8 @@ class GoogleGPTPage(BasePage): keywords="outdoor rugs,8x10 rugs,rug sizes,checkered rugs,5x7 rugs", title="Ruggable", company_url="https://ruggable.com", - serp_search_type=SerpSearchType.SEARCH, - serp_search_location=SerpSearchLocation.UNITED_STATES, + serp_search_type=SerpSearchType.search.name, + serp_search_location=SerpSearchLocations.UNITED_STATES.api_value, enable_html=False, selected_model=LargeLanguageModels.text_davinci_003.name, sampling_temperature=0.8, @@ -85,9 +82,7 @@ class RequestModelBase(BasePage.RequestModel): task_instructions: str | None query_instructions: str | None - selected_model: ( - 
typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None - ) + selected_model: LargeLanguageModels.api_enum | None max_search_urls: int | None @@ -95,7 +90,7 @@ class RequestModelBase(BasePage.RequestModel): max_context_words: int | None scroll_jump: int | None - embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None + embedding_model: EmbeddingModels.api_enum | None dense_weight: float | None = DocSearchRequest.__fields__[ "dense_weight" ].field_info @@ -216,7 +211,7 @@ def run_v2( self, request: "GoogleGPTPage.RequestModel", response: "GoogleGPTPage.ResponseModel", - ): + ) -> typing.Iterator[str | None]: model = LargeLanguageModels[request.selected_model] query_instructions = (request.query_instructions or "").strip() @@ -236,8 +231,8 @@ def run_v2( ) response.serp_results, links = get_links_from_serp_api( response.final_search_query, - search_type=request.serp_search_type, - search_location=request.serp_search_location, + search_type=SerpSearchType.from_api(request.serp_search_type), + search_location=SerpSearchLocations.from_api(request.serp_search_location), ) # extract links & their corresponding titles link_titles = {item.url: f"{item.title} | {item.snippet}" for item in links} diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py index a94e812ce..61488b82d 100644 --- a/recipes/GoogleImageGen.py +++ b/recipes/GoogleImageGen.py @@ -17,14 +17,14 @@ from daras_ai_v2.loom_video_widget import youtube_video from daras_ai_v2.serp_search import call_serp_api from daras_ai_v2.serp_search_locations import ( - serp_search_location_selectbox, GoogleSearchLocationMixin, + SerpSearchLocations, SerpSearchType, - SerpSearchLocation, + serp_search_location_selectbox, ) from daras_ai_v2.stable_diffusion import ( img2img, - Img2ImgModels, + ImageToImageModels, SD_IMG_MAX_SIZE, instruct_pix2pix, ) @@ -37,6 +37,7 @@ class GoogleImageGenPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/eb23c078-88da-11ee-aa86-02420a000165/web%20search%20render.png.png" workflow = Workflow.GOOGLE_IMAGE_GEN slug_versions = ["GoogleImageGen", "render-images-with-ai"] + sdk_method_name = "imageFromWebSearch" sane_defaults = dict( num_outputs=1, @@ -46,15 +47,15 @@ class GoogleImageGenPage(BasePage): sd_2_upscaling=False, seed=42, image_guidance_scale=1.2, - serp_search_type=SerpSearchType.SEARCH, - serp_search_location=SerpSearchLocation.UNITED_STATES, + serp_search_type=SerpSearchType.search.name, + serp_search_location=SerpSearchLocations.UNITED_STATES.api_value, ) class RequestModel(GoogleSearchLocationMixin, BasePage.RequestModel): search_query: str text_prompt: str - selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None + selected_model: ImageToImageModels.api_enum | None negative_prompt: str | None @@ -112,8 +113,8 @@ def run(self, state: dict): serp_results = call_serp_api( request.search_query, - search_type=SerpSearchType.IMAGES, - search_location=request.serp_search_location, + search_type=SerpSearchType.images, + search_location=SerpSearchLocations.from_api(request.serp_search_location), ) image_urls = [ link @@ -152,7 +153,7 @@ def run(self, state: dict): yield "Generating Images..." 
diff --git a/recipes/GoogleImageGen.py b/recipes/GoogleImageGen.py
index a94e812ce..61488b82d 100644
--- a/recipes/GoogleImageGen.py
+++ b/recipes/GoogleImageGen.py
@@ -17,14 +17,14 @@
 from daras_ai_v2.loom_video_widget import youtube_video
 from daras_ai_v2.serp_search import call_serp_api
 from daras_ai_v2.serp_search_locations import (
-    serp_search_location_selectbox,
     GoogleSearchLocationMixin,
+    SerpSearchLocations,
     SerpSearchType,
-    SerpSearchLocation,
+    serp_search_location_selectbox,
 )
 from daras_ai_v2.stable_diffusion import (
     img2img,
-    Img2ImgModels,
+    ImageToImageModels,
     SD_IMG_MAX_SIZE,
     instruct_pix2pix,
 )
@@ -37,6 +37,7 @@ class GoogleImageGenPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/eb23c078-88da-11ee-aa86-02420a000165/web%20search%20render.png.png"
     workflow = Workflow.GOOGLE_IMAGE_GEN
     slug_versions = ["GoogleImageGen", "render-images-with-ai"]
+    sdk_method_name = "imageFromWebSearch"

     sane_defaults = dict(
         num_outputs=1,
@@ -46,15 +47,15 @@ class GoogleImageGenPage(BasePage):
         sd_2_upscaling=False,
         seed=42,
         image_guidance_scale=1.2,
-        serp_search_type=SerpSearchType.SEARCH,
-        serp_search_location=SerpSearchLocation.UNITED_STATES,
+        serp_search_type=SerpSearchType.search.name,
+        serp_search_location=SerpSearchLocations.UNITED_STATES.api_value,
     )

     class RequestModel(GoogleSearchLocationMixin, BasePage.RequestModel):
         search_query: str
         text_prompt: str

-        selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None
+        selected_model: ImageToImageModels.api_enum | None

         negative_prompt: str | None

@@ -112,8 +113,8 @@ def run(self, state: dict):

         serp_results = call_serp_api(
             request.search_query,
-            search_type=SerpSearchType.IMAGES,
-            search_location=request.serp_search_location,
+            search_type=SerpSearchType.images,
+            search_location=SerpSearchLocations.from_api(request.serp_search_location),
         )
         image_urls = [
             link
@@ -152,7 +153,7 @@ def run(self, state: dict):

         yield "Generating Images..."

-        if request.selected_model == Img2ImgModels.instruct_pix2pix.name:
+        if request.selected_model == ImageToImageModels.instruct_pix2pix.name:
             state["output_images"] = instruct_pix2pix(
                 prompt=request.text_prompt,
                 num_outputs=request.num_outputs,
@@ -185,7 +186,7 @@ def render_form_v2(self):
             """,
             key="search_query",
         )
-        model_selector(Img2ImgModels)
+        model_selector(ImageToImageModels)
         gui.text_area(
             """
             #### 👩‍💻 Prompt
@@ -199,7 +200,7 @@ def render_usage_guide(self):
         youtube_video("rnjvtaYYe8g")

     def render_settings(self):
-        img_model_settings(Img2ImgModels, render_model_selector=False)
+        img_model_settings(ImageToImageModels, render_model_selector=False)
         serp_search_location_selectbox()

     def render_output(self):
diff --git a/recipes/ImageSegmentation.py b/recipes/ImageSegmentation.py
index f886c2b4e..c9e16d988 100644
--- a/recipes/ImageSegmentation.py
+++ b/recipes/ImageSegmentation.py
@@ -16,7 +16,7 @@
 )
 from daras_ai_v2.base import BasePage
 from daras_ai_v2.enum_selector_widget import enum_selector
-from daras_ai_v2.image_segmentation import u2net, ImageSegmentationModels, dis
+from daras_ai_v2.image_segmentation import ImageSegmentationModels, dis, u2net
 from daras_ai_v2.img_io import opencv_to_pil, pil_to_bytes
 from daras_ai_v2.loom_video_widget import youtube_video
 from daras_ai_v2.polygon_fitter import (
@@ -37,6 +37,7 @@ class ImageSegmentationPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/06fc595e-88db-11ee-b428-02420a000168/AI%20Background%20Remover.png.png"
     workflow = Workflow.IMAGE_SEGMENTATION
     slug_versions = ["ImageSegmentation", "remove-image-background-with-ai"]
+    sdk_method_name = "removeBackground"

     sane_defaults = {
         "mask_threshold": 0.5,
@@ -50,9 +51,7 @@ class ImageSegmentationPage(BasePage):
     class RequestModel(BasePage.RequestModel):
         input_image: FieldHttpUrl

-        selected_model: (
-            typing.Literal[tuple(e.name for e in ImageSegmentationModels)] | None
-        )
+        selected_model: ImageSegmentationModels.api_enum | None

         mask_threshold: float | None

         rect_persepective_transform: bool | None
diff --git a/recipes/Img2Img.py b/recipes/Img2Img.py
index 97de89ab7..601a1d378 100644
--- a/recipes/Img2Img.py
+++ b/recipes/Img2Img.py
@@ -2,7 +2,7 @@
 from daras_ai_v2.pydantic_validation import FieldHttpUrl

 import requests
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 import gooey_gui as gui
 from bots.models import Workflow
@@ -10,7 +10,7 @@
 from daras_ai_v2.img_model_settings_widgets import img_model_settings
 from daras_ai_v2.loom_video_widget import youtube_video
 from daras_ai_v2.stable_diffusion import (
-    Img2ImgModels,
+    ImageToImageModels,
     img2img,
     SD_IMG_MAX_SIZE,
     instruct_pix2pix,
@@ -27,6 +27,7 @@ class Img2ImgPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/bcc9351a-88d9-11ee-bf6c-02420a000166/Edit%20an%20image%20with%20AI%201.png.png"
     workflow = Workflow.IMG_2_IMG
     slug_versions = ["Img2Img", "ai-photo-editor"]
+    sdk_method_name = "remixImage"

     sane_defaults = {
         "num_outputs": 1,
@@ -45,12 +46,10 @@ class RequestModel(BasePage.RequestModel):
         input_image: FieldHttpUrl
         text_prompt: str | None

-        selected_model: typing.Literal[tuple(e.name for e in Img2ImgModels)] | None
+        selected_model: ImageToImageModels.api_enum | None
         selected_controlnet_model: (
-            list[typing.Literal[tuple(e.name for e in ControlNetModels)]]
-            | typing.Literal[tuple(e.name for e in ControlNetModels)]
-            | None
-        )
+            list[ControlNetModels.api_enum] | ControlNetModels.api_enum | None
+        ) = Field(**{"x-fern-type-name": "SelectedControlNetModels"})
         negative_prompt: str | None

         num_outputs: int | None
@@ -124,7 +123,7 @@ def render_description(self):
         )

     def render_settings(self):
-        img_model_settings(Img2ImgModels)
+        img_model_settings(ImageToImageModels)

     def render_usage_guide(self):
         youtube_video("narcZNyuNAg")
@@ -161,7 +160,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:

         yield "Generating Image..."

-        if request.selected_model == Img2ImgModels.instruct_pix2pix.name:
+        if request.selected_model == ImageToImageModels.instruct_pix2pix.name:
             state["output_images"] = instruct_pix2pix(
                 prompt=request.text_prompt,
                 num_outputs=request.num_outputs,
@@ -205,7 +204,7 @@ def preview_description(self, state: dict) -> str:
     def get_raw_price(self, state: dict) -> int:
         selected_model = state.get("selected_model")
         match selected_model:
-            case Img2ImgModels.dall_e.name:
+            case ImageToImageModels.dall_e.name:
                 unit_price = 20
             case _:
                 unit_price = 5
diff --git a/recipes/LetterWriter.py b/recipes/LetterWriter.py
index ff39fb3aa..53a0296af 100644
--- a/recipes/LetterWriter.py
+++ b/recipes/LetterWriter.py
@@ -18,6 +18,7 @@ class LetterWriterPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png"
     workflow = Workflow.LETTER_WRITER
     slug_versions = ["LetterWriter"]
+    sdk_method_name = ""

     class RequestModel(BasePage.RequestModel):
         action_id: str
@@ -46,6 +47,10 @@ class ResponseModel(BaseModel):
         generated_input_prompt: str
         final_prompt: str

+    @classmethod
+    def get_openapi_extra(cls) -> dict[str, typing.Any]:
+        return {"x-fern-ignore": True}
+
     def render_description(self):
         gui.write(
             """
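The `**{"x-fern-type-name": ...}` spread is required because the key is not a valid Python identifier. Under pydantic v1, unknown `Field()` kwargs land in `field_info.extra` and are copied verbatim into the field's JSON schema, which is where Fern picks them up. A standalone sketch of the mechanism:

```python
from pydantic import BaseModel, Field  # pydantic v1 semantics


class Example(BaseModel):
    selected_controlnet_model: list[str] | None = Field(
        None, **{"x-fern-type-name": "SelectedControlNetModels"}
    )


# the extra key shows up verbatim in the generated schema
print(Example.schema()["properties"]["selected_controlnet_model"])
# expected to include: 'x-fern-type-name': 'SelectedControlNetModels'
```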
diff --git a/recipes/Lipsync.py b/recipes/Lipsync.py
index 6dae9320e..0a71e7ac8 100644
--- a/recipes/Lipsync.py
+++ b/recipes/Lipsync.py
@@ -8,7 +8,7 @@
 from daras_ai_v2.base import BasePage
 from daras_ai_v2.enum_selector_widget import enum_selector
 from daras_ai_v2.lipsync_api import run_wav2lip, run_sadtalker, LipsyncSettings
-from daras_ai_v2.lipsync_settings_widgets import lipsync_settings, LipsyncModel
+from daras_ai_v2.lipsync_settings_widgets import lipsync_settings, LipsyncModels
 from daras_ai_v2.loom_video_widget import youtube_video
 from daras_ai_v2.pydantic_validation import FieldHttpUrl
@@ -19,7 +19,7 @@

 def price_for_model(selected_model: str | None) -> float:
-    if selected_model == LipsyncModel.SadTalker.name:
+    if selected_model == LipsyncModels.SadTalker.name:
         multiplier = 2
     else:
         multiplier = 1
@@ -31,11 +31,10 @@ class LipsyncPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f33e6332-88d8-11ee-89f9-02420a000169/Lipsync%20TTS.png.png"
     workflow = Workflow.LIPSYNC
     slug_versions = ["Lipsync"]
+    sdk_method_name = "lipsync"

     class RequestModel(LipsyncSettings, BasePage.RequestModel):
-        selected_model: typing.Literal[tuple(e.name for e in LipsyncModel)] = (
-            LipsyncModel.Wav2Lip.name
-        )
+        selected_model: LipsyncModels.api_enum = LipsyncModels.Wav2Lip.name
         input_audio: FieldHttpUrl = None

     class ResponseModel(BaseModel):
@@ -71,7 +70,7 @@ def render_form_v2(self):
         )

         enum_selector(
-            LipsyncModel,
+            LipsyncModels,
             label="###### Lipsync Model",
             key="selected_model",
             use_selectbox=True,
@@ -92,10 +91,10 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
         else:
             max_frames = 250

-        model = LipsyncModel[request.selected_model]
+        model = LipsyncModels[request.selected_model]
         yield f"Running {model.value}..."

         match model:
-            case LipsyncModel.Wav2Lip:
+            case LipsyncModels.Wav2Lip:
                 state["output_video"], state["duration_sec"] = run_wav2lip(
                     face=request.input_face,
                     audio=request.input_audio,
@@ -107,7 +106,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:
                 ),
                 max_frames=max_frames,
             )
-            case LipsyncModel.SadTalker:
+            case LipsyncModels.SadTalker:
                 state["output_video"], state["duration_sec"] = run_sadtalker(
                     request.sadtalker_settings,
                     face=request.input_face,
diff --git a/recipes/LipsyncTTS.py b/recipes/LipsyncTTS.py
index d557cd663..ef9f9b57d 100644
--- a/recipes/LipsyncTTS.py
+++ b/recipes/LipsyncTTS.py
@@ -6,7 +6,7 @@
 import gooey_gui as gui
 from bots.models import Workflow
 from daras_ai_v2.enum_selector_widget import enum_selector
-from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModel
+from daras_ai_v2.lipsync_api import LipsyncSettings, LipsyncModels
 from daras_ai_v2.loom_video_widget import youtube_video
 from daras_ai_v2.safety_checker import safety_checker
 from daras_ai_v2.text_to_speech_settings_widgets import (
@@ -23,6 +23,7 @@ class LipsyncTTSPage(LipsyncPage, TextToSpeechPage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/1acfa370-88d9-11ee-bf6c-02420a000166/Lipsync%20with%20audio%201.png.png"
     workflow = Workflow.LIPSYNC_TTS
     slug_versions = ["LipsyncTTS", "lipsync-maker"]
+    sdk_method_name = "lipsyncTTS"

     sane_defaults = {
         "elevenlabs_model": "eleven_multilingual_v2",
@@ -31,9 +32,7 @@ class LipsyncTTSPage(LipsyncPage, TextToSpeechPage):
     }

     class RequestModel(LipsyncSettings, TextToSpeechPage.RequestModel):
-        selected_model: typing.Literal[tuple(e.name for e in LipsyncModel)] = (
-            LipsyncModel.Wav2Lip.name
-        )
+        selected_model: LipsyncModels.api_enum = LipsyncModels.Wav2Lip.name

     class ResponseModel(BaseModel):
         audio_url: str | None
@@ -76,7 +75,7 @@ def render_form_v2(self):
         )

         enum_selector(
-            LipsyncModel,
+            LipsyncModels,
             label="###### Lipsync Model",
             key="selected_model",
             use_selectbox=True,
diff --git a/recipes/ObjectInpainting.py b/recipes/ObjectInpainting.py
index 3ec1f6f89..65f902ca4 100644
--- a/recipes/ObjectInpainting.py
+++ b/recipes/ObjectInpainting.py
@@ -33,6 +33,7 @@ class ObjectInpaintingPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/f07b731e-88d9-11ee-a658-02420a000163/W.I.3.png.png"
     workflow = Workflow.OBJECT_INPAINTING
     slug_versions = ["ObjectInpainting", "product-photo-background-generator"]
+    sdk_method_name = "productImage"

     sane_defaults = {
         "mask_threshold": 0.7,
@@ -55,7 +56,7 @@ class RequestModel(BasePage.RequestModel):

         mask_threshold: float | None

-        selected_model: typing.Literal[tuple(e.name for e in InpaintingModels)] | None
+        selected_model: InpaintingModels.api_enum | None

         negative_prompt: str | None
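Note the defaulting pattern on these `api_enum`-typed fields: the default is the member's `.name` (the wire string), not the member itself, since the annotation is the dynamically generated enum. Runtime code then recovers the member by name lookup. Sketched with a toy stand-in (the real `LipsyncModels` has its own labels):

```python
from daras_ai_v2.custom_enum import GooeyEnum


class LipsyncModels(GooeyEnum):  # toy stand-in, for illustration only
    Wav2Lip = "LipSync (Wav2Lip)"
    SadTalker = "LipSync (SadTalker)"


# the request stores the plain string "Wav2Lip" -- a valid api_enum value,
# and (with use_enum_values) exactly what the parsed model holds
default = LipsyncModels.Wav2Lip.name

# run() recovers the member via name lookup:
assert LipsyncModels[default] is LipsyncModels.Wav2Lip
```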
diff --git a/recipes/QRCodeGenerator.py b/recipes/QRCodeGenerator.py
index cfa9ce013..7f61f8d76 100644
--- a/recipes/QRCodeGenerator.py
+++ b/recipes/QRCodeGenerator.py
@@ -8,7 +8,7 @@
 from django.core.exceptions import ValidationError
 from django.core.validators import URLValidator
 from furl import furl
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 from pyzbar import pyzbar

 import gooey_gui as gui
@@ -30,14 +30,13 @@
 from daras_ai_v2.repositioning import reposition_object, repositioning_preview_widget
 from daras_ai_v2.safety_checker import safety_checker
 from daras_ai_v2.stable_diffusion import (
-    Text2ImgModels,
+    TextToImageModels,
     controlnet,
     ControlNetModels,
-    Img2ImgModels,
+    ImageToImageModels,
     Schedulers,
 )
 from daras_ai_v2.vcard import VCARD
-from recipes.EmailFaceInpainting import get_photo_for_email
 from recipes.SocialLookupEmail import get_profile_for_email
 from url_shortener.models import ShortenedURL
 from daras_ai_v2.enum_selector_widget import enum_multiselect
@@ -58,6 +57,7 @@ class QRCodeGeneratorPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/03d6538e-88d5-11ee-ad97-02420a00016c/W.I.2.png.png"
     workflow = Workflow.QR_CODE
     slug_versions = ["art-qr-code", "qr", "qr-code"]
+    sdk_method_name = "qrCode"

     sane_defaults = dict(
         num_outputs=2,
@@ -82,7 +82,9 @@ def __init__(self, *args, **kwargs):
     class RequestModel(BasePage.RequestModel):
         qr_code_data: str | None
         qr_code_input_image: FieldHttpUrl | None
-        qr_code_vcard: VCARD | None
+        qr_code_vcard: VCARD | None = Field(
+            title="VCard", **{"x-fern-type-name": "VCard"}
+        )
         qr_code_file: FieldHttpUrl | None

         use_url_shortener: bool | None
@@ -90,18 +92,14 @@ class RequestModel(BasePage.RequestModel):
         text_prompt: str
         negative_prompt: str | None
         image_prompt: str | None
-        image_prompt_controlnet_models: (
-            list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None
-        )
+        image_prompt_controlnet_models: list[ControlNetModels.api_enum] | None
         image_prompt_strength: float | None
         image_prompt_scale: float | None
         image_prompt_pos_x: float | None
         image_prompt_pos_y: float | None

-        selected_model: typing.Literal[tuple(e.name for e in Text2ImgModels)] | None
-        selected_controlnet_model: (
-            list[typing.Literal[tuple(e.name for e in ControlNetModels)], ...] | None
-        )
+        selected_model: TextToImageModels.api_enum | None
+        selected_controlnet_model: list[ControlNetModels.api_enum] | None
        output_width: int | None
        output_height: int | None

@@ -111,7 +109,7 @@ class RequestModel(BasePage.RequestModel):
         num_outputs: int | None
         quality: int | None

-        scheduler: typing.Literal[tuple(e.name for e in Schedulers)] | None
+        scheduler: Schedulers.api_enum | None

         seed: int | None

@@ -297,7 +295,7 @@ def render_settings(self):
         )

         img_model_settings(
-            Img2ImgModels,
+            ImageToImageModels,
             show_scheduler=True,
             require_controlnet=True,
             extra_explanations={
@@ -486,7 +484,7 @@ def run(self, state: dict) -> typing.Iterator[str | None]:

         state["raw_images"] = raw_images = []

-        yield f"Running {Text2ImgModels[request.selected_model].value}..."
+        yield f"Running {TextToImageModels[request.selected_model].label}..."
         if isinstance(request.selected_controlnet_model, str):
             request.selected_controlnet_model = [request.selected_controlnet_model]
         init_images = [image] * len(request.selected_controlnet_model)
@@ -547,12 +545,14 @@ def preview_description(self, state: dict) -> str:
         """

     def get_raw_price(self, state: dict) -> int:
-        selected_model = state.get("selected_model", Text2ImgModels.dream_shaper.name)
+        selected_model = state.get(
+            "selected_model", TextToImageModels.dream_shaper.name
+        )
         total = 5
         match selected_model:
-            case Text2ImgModels.deepfloyd_if.name:
+            case TextToImageModels.deepfloyd_if.name:
                 total += 3
-            case Text2ImgModels.dall_e.name:
+            case TextToImageModels.dall_e.name:
                 total += 10
         return total * state.get("num_outputs", 1)
diff --git a/recipes/RelatedQnA.py b/recipes/RelatedQnA.py
index 6372ce65e..f43c84097 100644
--- a/recipes/RelatedQnA.py
+++ b/recipes/RelatedQnA.py
@@ -8,10 +8,7 @@
     LargeLanguageModels,
 )
 from daras_ai_v2.serp_search import get_related_questions_from_serp_api
-from daras_ai_v2.serp_search_locations import (
-    SerpSearchLocation,
-    SerpSearchType,
-)
+from daras_ai_v2.serp_search_locations import SerpSearchLocations, SerpSearchType
 from recipes.DocSearch import render_doc_search_step, EmptySearchResults
 from recipes.GoogleGPT import GoogleGPTPage
 from recipes.RelatedQnADoc import render_qna_outputs
@@ -28,6 +25,7 @@ class RelatedQnAPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/37b0ba22-88d6-11ee-b549-02420a000167/People%20also%20ask.png.png"
     workflow = Workflow.RELATED_QNA_MAKER
     slug_versions = ["related-qna-maker"]
+    sdk_method_name = "seoPeopleAlsoAsk"

     price = 75

@@ -36,8 +34,8 @@ class RelatedQnAPage(BasePage):
         max_context_words=200,
         scroll_jump=5,
         dense_weight=1.0,
-        serp_search_type=SerpSearchType.SEARCH,
-        serp_search_location=SerpSearchLocation.UNITED_STATES,
+        serp_search_type=SerpSearchType.search.name,
+        serp_search_location=SerpSearchLocations.UNITED_STATES.api_value,
     )

     class RequestModel(GoogleGPTPage.RequestModel):
@@ -117,7 +115,7 @@ def run_v2(
             related_questions,
         ) = get_related_questions_from_serp_api(
             request.search_query,
-            search_location=request.serp_search_location,
+            search_location=SerpSearchLocations.from_api(request.serp_search_location),
         )
         all_questions = [request.search_query] + related_questions[:9]

diff --git a/recipes/RelatedQnADoc.py b/recipes/RelatedQnADoc.py
index 3f8c2d2d8..8362a7046 100644
--- a/recipes/RelatedQnADoc.py
+++ b/recipes/RelatedQnADoc.py
@@ -7,10 +7,7 @@
 from daras_ai_v2.language_model import LargeLanguageModels
 from daras_ai_v2.search_ref import CitationStyles
 from daras_ai_v2.serp_search import get_related_questions_from_serp_api
-from daras_ai_v2.serp_search_locations import (
-    SerpSearchLocation,
-    SerpSearchType,
-)
+from daras_ai_v2.serp_search_locations import SerpSearchLocations, SerpSearchType
 from daras_ai_v2.vector_search import render_sources_widget
 from recipes.DocSearch import DocSearchPage, render_doc_search_step, EmptySearchResults
 from recipes.GoogleGPT import render_output_with_refs, GoogleSearchMixin
@@ -27,14 +24,15 @@ class RelatedQnADocPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png"
     workflow = Workflow.RELATED_QNA_MAKER_DOC
     slug_versions = ["related-qna-maker-doc"]
+    sdk_method_name = "seoPeopleAlsoAskDoc"

     price = 100

     sane_defaults = dict(
         citation_style=CitationStyles.number.name,
         dense_weight=1.0,
-        serp_search_type=SerpSearchType.SEARCH,
-        serp_search_location=SerpSearchLocation.UNITED_STATES,
+        serp_search_type=SerpSearchType.search.name,
+        serp_search_location=SerpSearchLocations.UNITED_STATES.api_value,
     )

     class RequestModel(GoogleSearchMixin, DocSearchPage.RequestModel):
@@ -111,7 +109,7 @@ def run_v2(
             related_questions,
         ) = get_related_questions_from_serp_api(
             request.search_query,
-            search_location=request.serp_search_location,
+            search_location=SerpSearchLocations.from_api(request.serp_search_location),
         )
         all_questions = [request.search_query] + related_questions[:9]

diff --git a/recipes/SEOSummary.py b/recipes/SEOSummary.py
index 1b4e65365..2274b4662 100644
--- a/recipes/SEOSummary.py
+++ b/recipes/SEOSummary.py
@@ -28,9 +28,9 @@
 from daras_ai_v2.scrollable_html_widget import scrollable_html
 from daras_ai_v2.serp_search import get_links_from_serp_api
 from daras_ai_v2.serp_search_locations import (
-    serp_search_settings,
-    SerpSearchLocation,
+    SerpSearchLocations,
     SerpSearchType,
+    serp_search_settings,
 )
 from daras_ai_v2.settings import EXTERNAL_REQUEST_TIMEOUT_SEC
 from recipes.GoogleGPT import GoogleSearchMixin
@@ -61,6 +61,7 @@ class SEOSummaryPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85f38b42-88d6-11ee-ad97-02420a00016c/Create%20SEO%20optimized%20content%20option%202.png.png"
     workflow = Workflow.SEO_SUMMARY
     slug_versions = ["SEOSummary", "seo-paragraph-generator"]
+    sdk_method_name = "seoContent"

     def preview_image(self, state: dict) -> str | None:
         return SEO_SUMMARY_DEFAULT_META_IMG
@@ -73,8 +74,8 @@ def preview_description(self, state: dict) -> str:
         keywords="outdoor rugs,8x10 rugs,rug sizes,checkered rugs,5x7 rugs",
         title="Ruggable",
         company_url="https://ruggable.com",
-        serp_search_type=SerpSearchType.SEARCH,
-        serp_search_location=SerpSearchLocation.UNITED_STATES,
+        serp_search_type=SerpSearchType.search.name,
+        serp_search_location=SerpSearchLocations.UNITED_STATES.api_value,
         enable_html=False,
         selected_model=LargeLanguageModels.text_davinci_003.name,
         sampling_temperature=0.8,
@@ -90,6 +91,9 @@ def preview_description(self, state: dict) -> str:
     )

     class RequestModelBase(BaseModel):
+        class Config:
+            use_enum_values = True
+
         search_query: str
         keywords: str
         title: str
@@ -99,9 +103,7 @@ class RequestModelBase(BaseModel):

         enable_html: bool | None

-        selected_model: (
-            typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
-        )
+        selected_model: LargeLanguageModels.api_enum | None

         max_search_urls: int | None

@@ -274,8 +276,8 @@ def run(self, state: dict) -> typing.Iterator[str | None]:

         serp_results, links = get_links_from_serp_api(
             request.search_query,
-            search_type=request.serp_search_type,
-            search_location=request.serp_search_location,
+            search_type=SerpSearchType.from_api(request.serp_search_type),
+            search_location=SerpSearchLocations.from_api(request.serp_search_location),
         )
         state["serp_results"] = serp_results
         state["search_urls"] = [it.url for it in links]
@@ -314,8 +316,8 @@ def _crosslink_keywords(output_content, request):
     all_results = map_parallel(
         lambda keyword: get_links_from_serp_api(
             f"site:{host} {keyword}",
-            search_type=request.serp_search_type,
-            search_location=request.serp_search_location,
+            search_type=SerpSearchType.from_api(request.serp_search_type),
+            search_location=SerpSearchLocations.from_api(request.serp_search_location),
         )[1],
         relevant_keywords,
     )
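`SEOSummaryPage.RequestModelBase` derives from `BaseModel` directly, so it has to repeat the `use_enum_values = True` config that `BasePage.RequestModel` already carries. The effect, sketched with a toy model (pydantic v1):

```python
from enum import Enum

from pydantic import BaseModel


class Color(Enum):
    red = "red"


class WithEnumValues(BaseModel):
    class Config:
        use_enum_values = True

    color: Color


class WithoutEnumValues(BaseModel):
    color: Color


print(WithEnumValues(color="red").color)     # 'red' -- a JSON-safe string
print(WithoutEnumValues(color="red").color)  # Color.red -- an enum member
```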
diff --git a/recipes/SmartGPT.py b/recipes/SmartGPT.py
index fa4066fc7..7141132c8 100644
--- a/recipes/SmartGPT.py
+++ b/recipes/SmartGPT.py
@@ -1,7 +1,7 @@
 import typing

 import jinja2.sandbox
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 import gooey_gui as gui
 from bots.models import Workflow
@@ -29,6 +29,7 @@ class SmartGPTPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/ffd24ad8-88d7-11ee-a658-02420a000163/SmartGPT.png.png"
     workflow = Workflow.SMART_GPT
     slug_versions = ["SmartGPT"]
+    sdk_method_name = "smartGPT"
     price = 20

     class RequestModelBase(BasePage.RequestModel):
@@ -38,9 +39,7 @@ class RequestModelBase(BasePage.RequestModel):
         reflexion_prompt: str | None
         dera_prompt: str | None

-        selected_model: (
-            typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
-        )
+        selected_model: LargeLanguageModels.api_enum | None

     class RequestModel(LanguageModelSettings, RequestModelBase):
         pass
diff --git a/recipes/SocialLookupEmail.py b/recipes/SocialLookupEmail.py
index bc3a0dea1..f1e3a3bd5 100644
--- a/recipes/SocialLookupEmail.py
+++ b/recipes/SocialLookupEmail.py
@@ -2,7 +2,7 @@
 import typing

 import requests
-from pydantic import BaseModel
+from pydantic import BaseModel, Field

 import gooey_gui as gui
 from bots.models import Workflow
@@ -32,6 +32,7 @@ class SocialLookupEmailPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5fbd475a-88d7-11ee-aac9-02420a00016b/personalized%20email.png.png"
     workflow = Workflow.SOCIAL_LOOKUP_EMAIL
     slug_versions = ["SocialLookupEmail", "email-writer-with-profile-lookup"]
+    sdk_method_name = "personalizeEmail"

     sane_defaults = {
         "selected_model": LargeLanguageModels.gpt_4.name,
@@ -54,9 +55,7 @@ class RequestModelBase(BasePage.RequestModel):
         # domain: str | None
         # key_words: str | None

-        selected_model: (
-            typing.Literal[tuple(e.name for e in LargeLanguageModels)] | None
-        )
+        selected_model: LargeLanguageModels.api_enum | None

     class RequestModel(LanguageModelSettings, RequestModelBase):
         pass
diff --git a/recipes/Text2Audio.py b/recipes/Text2Audio.py
index f3199a95b..302bf236f 100644
--- a/recipes/Text2Audio.py
+++ b/recipes/Text2Audio.py
@@ -1,6 +1,6 @@
 import typing

-from enum import Enum
+from daras_ai_v2.custom_enum import GooeyEnum
 from daras_ai_v2.pydantic_validation import FieldHttpUrl
 from pydantic import BaseModel

@@ -18,7 +18,7 @@
 DEFAULT_TEXT2AUDIO_META_IMG = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/85cf8ea4-9457-11ee-bd77-02420a0001ce/Text%20guided%20audio.jpg.png"


-class Text2AudioModels(Enum):
+class Text2AudioModels(GooeyEnum):
     audio_ldm = "AudioLDM (CVSSP)"


@@ -32,6 +32,7 @@ class Text2AudioPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/a4481d58-88d9-11ee-aa86-02420a000165/Text%20guided%20audio%20generator.png.png"
     workflow = Workflow.TEXT_2_AUDIO
     slug_versions = ["text2audio"]
+    sdk_method_name = "textToMusic"

     sane_defaults = dict(
         seed=42,
@@ -50,13 +51,11 @@ class RequestModel(BasePage.RequestModel):
         seed: int | None
         sd_2_upscaling: bool | None

-        selected_models: (
-            list[typing.Literal[tuple(e.name for e in Text2AudioModels)]] | None
-        )
+        selected_models: list[Text2AudioModels.api_enum] | None

     class ResponseModel(BaseModel):
         output_audios: dict[
-            typing.Literal[tuple(e.name for e in Text2AudioModels)],
+            Text2AudioModels.api_enum,
             list[FieldHttpUrl],
         ]
diff --git a/recipes/TextToSpeech.py b/recipes/TextToSpeech.py
index dbc877a17..0488a6fba 100644
--- a/recipes/TextToSpeech.py
+++ b/recipes/TextToSpeech.py
@@ -31,7 +31,7 @@
 class TextToSpeechSettings(BaseModel):
-    tts_provider: typing.Literal[tuple(e.name for e in TextToSpeechProviders)] | None
+    tts_provider: TextToSpeechProviders.api_enum | None

     uberduck_voice_name: str | None
     uberduck_speaking_rate: float | None
@@ -55,8 +55,8 @@ class TextToSpeechSettings(BaseModel):

     azure_voice_name: str | None

-    openai_voice_name: OpenAI_TTS_Voices.api_choices | None
-    openai_tts_model: OpenAI_TTS_Models.api_choices | None
+    openai_voice_name: OpenAI_TTS_Voices.api_enum | None
+    openai_tts_model: OpenAI_TTS_Models.api_enum | None


 class TextToSpeechPage(BasePage):
@@ -69,6 +69,7 @@ class TextToSpeechPage(BasePage):
         "text2speech",
         "compare-text-to-speech-engines",
     ]
+    sdk_method_name = "textToSpeech"

     sane_defaults = {
         "tts_provider": TextToSpeechProviders.GOOGLE_TTS.value,
diff --git a/recipes/Translation.py b/recipes/Translation.py
index b8061beb2..19ce68795 100644
--- a/recipes/Translation.py
+++ b/recipes/Translation.py
@@ -38,13 +38,14 @@ class TranslationPage(BasePage):
     title = "Compare AI Translations"
     workflow = Workflow.TRANSLATION
     slug_versions = ["translate", "translation", "compare-ai-translation"]
+    sdk_method_name = "translate"

     class RequestModelBase(BasePage.RequestModel):
         texts: list[str] = Field([])

-        selected_model: (
-            typing.Literal[tuple(e.name for e in TranslationModels)]
-        ) | None = Field(TranslationModels.google.name)
+        selected_model: TranslationModels.api_enum | None = Field(
+            TranslationModels.google.name
+        )

     class RequestModel(TranslationOptions, RequestModelBase):
         pass
" @@ -211,15 +212,15 @@ class RequestModelBase(BasePage.RequestModel): max_context_words: int | None scroll_jump: int | None - embedding_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None + embedding_model: EmbeddingModels.api_enum | None dense_weight: float | None = DocSearchRequest.__fields__[ "dense_weight" ].field_info - citation_style: typing.Literal[tuple(e.name for e in CitationStyles)] | None + citation_style: CitationStyles.api_enum | None use_url_shortener: bool | None - asr_model: typing.Literal[tuple(e.name for e in AsrModels)] | None = Field( + asr_model: AsrModels.api_enum | None = Field( title="Speech-to-Text Provider", description="Choose a model to transcribe incoming audio messages to text.", ) @@ -228,9 +229,7 @@ class RequestModelBase(BasePage.RequestModel): description="Choose a language to transcribe incoming audio messages to text.", ) - translation_model: ( - typing.Literal[tuple(e.name for e in TranslationModels)] | None - ) + translation_model: TranslationModels.api_enum | None user_language: str | None = Field( title="User Language", description="Choose a language to translate incoming text & audio messages to English and responses back to your selected language. Useful for low-resource languages.", @@ -249,9 +248,7 @@ class RequestModelBase(BasePage.RequestModel): """, ) - lipsync_model: typing.Literal[tuple(e.name for e in LipsyncModel)] = ( - LipsyncModel.Wav2Lip.name - ) + lipsync_model: LipsyncModels.api_enum = LipsyncModels.Wav2Lip.name tools: list[LLMTools] | None = Field( title="🛠️ Tools", @@ -286,6 +283,13 @@ class ResponseModel(BaseModel): finish_reason: list[str] | None + @classmethod + def get_openapi_extra(cls) -> dict[str, typing.Any]: + return { + "x-fern-sdk-group-name": cls.sdk_group_name, + "x-fern-sdk-method-name": cls.sdk_method_name, + } + def preview_image(self, state: dict) -> str | None: return DEFAULT_COPILOT_META_IMG @@ -372,7 +376,7 @@ def render_form_v2(self): key="input_face", ) enum_selector( - LipsyncModel, + LipsyncModels, label="###### Lipsync Model", key="lipsync_model", use_selectbox=True, diff --git a/recipes/asr_page.py b/recipes/asr_page.py index f68be1a08..509540b5a 100644 --- a/recipes/asr_page.py +++ b/recipes/asr_page.py @@ -38,19 +38,18 @@ class AsrPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/5fb7e5f6-88d9-11ee-aa86-02420a000165/Speech.png.png" workflow = Workflow.ASR slug_versions = ["asr", "speech"] + sdk_method_name = "speechRecognition" sane_defaults = dict(output_format=AsrOutputFormat.text.name) class RequestModelBase(BasePage.RequestModel): documents: list[FieldHttpUrl] - selected_model: typing.Literal[tuple(e.name for e in AsrModels)] | None + selected_model: AsrModels.api_enum | None language: str | None - translation_model: ( - typing.Literal[tuple(e.name for e in TranslationModels)] | None - ) + translation_model: TranslationModels.api_enum | None - output_format: typing.Literal[tuple(e.name for e in AsrOutputFormat)] | None + output_format: AsrOutputFormat.api_enum | None google_translate_target: str | None = Field( deprecated=True, diff --git a/recipes/embeddings_page.py b/recipes/embeddings_page.py index e65580681..d3cff0244 100644 --- a/recipes/embeddings_page.py +++ b/recipes/embeddings_page.py @@ -17,11 +17,12 @@ class EmbeddingsPage(BasePage): explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png" workflow = 
diff --git a/recipes/embeddings_page.py b/recipes/embeddings_page.py
index e65580681..d3cff0244 100644
--- a/recipes/embeddings_page.py
+++ b/recipes/embeddings_page.py
@@ -17,11 +17,12 @@ class EmbeddingsPage(BasePage):
     explore_image = "https://storage.googleapis.com/dara-c1b52.appspot.com/daras_ai/media/aeb83ee8-889e-11ee-93dc-02420a000143/Youtube%20transcripts%20GPT%20extractions.png.png"
     workflow = Workflow.EMBEDDINGS
     slug_versions = ["embeddings", "embed", "text-embedings"]
+    sdk_method_name = "embed"

     price = 1

     class RequestModel(BasePage.RequestModel):
         texts: list[str]
-        selected_model: typing.Literal[tuple(e.name for e in EmbeddingModels)] | None
+        selected_model: EmbeddingModels.api_enum | None

     class ResponseModel(BaseModel):
         embeddings: list[list[float]]
diff --git a/routers/api.py b/routers/api.py
index 73c5ec7b9..02b8e357a 100644
--- a/routers/api.py
+++ b/routers/api.py
@@ -5,6 +5,7 @@
 import typing

 import gooey_gui as gui
+from fastapi import Query
 from fastapi import Depends
 from fastapi import Form
 from fastapi import HTTPException
@@ -152,10 +153,13 @@ def script_to_api(page_cls: typing.Type[BasePage]):
         operation_id=page_cls.slug_versions[0],
         tags=[page_cls.title],
         name=page_cls.title + " (v2 sync)",
+        openapi_extra={"x-fern-ignore": True},
+        include_in_schema=False,
     )
     def run_api_json(
         request: Request,
         page_request: request_model,
+        example_id: str | None = None,
         user: AppUser = Depends(api_auth_header),
     ):
         result, sr = submit_api_call(
@@ -200,11 +204,13 @@ def run_api_form(
         name=page_cls.title + " (v3 async)",
         tags=[page_cls.title],
         status_code=202,
+        openapi_extra=page_cls.get_openapi_extra(),
     )
     def run_api_json_async(
         request: Request,
         response: Response,
         page_request: request_model,
+        example_id: str | None = Query(default=None),
         user: AppUser = Depends(api_auth_header),
     ):
         result, sr = submit_api_call(
@@ -255,6 +261,7 @@ def run_api_form_async(
         operation_id="status__" + page_cls.slug_versions[0],
         tags=[page_cls.title],
         name=page_cls.title + " (v3 status)",
+        openapi_extra={"x-fern-ignore": True},
     )
     def get_run_status(
         run_id: str,
@@ -426,11 +433,16 @@ class BalanceResponse(BaseModel):
     balance: int = Field(description="Current balance in credits")


-@app.get("/v1/balance/", response_model=BalanceResponse, tags=["Misc"])
+@app.get(
+    "/v1/balance/",
+    response_model=BalanceResponse,
+    tags=["Misc"],
+    openapi_extra={"x-fern-sdk-method-name": "getBalance"},
+)
 def get_balance(user: AppUser = Depends(api_auth_header)):
     return BalanceResponse(balance=user.balance)


-@app.get("/status")
+@app.get("/status", openapi_extra={"x-fern-ignore": True})
 async def health():
     return "OK"
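FastAPI merges the `openapi_extra` dict verbatim into the operation object of the generated spec; that is the whole delivery mechanism for the `x-fern-*` hints and ignore flags above. A minimal self-contained sketch:

```python
from fastapi import FastAPI

app = FastAPI()


@app.get("/status", openapi_extra={"x-fern-ignore": True})
async def health():
    return "OK"


spec = app.openapi()
print(spec["paths"]["/status"]["get"]["x-fern-ignore"])  # True
```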
diff --git a/routers/bots_api.py b/routers/bots_api.py
index 021caae4b..c99dadec0 100644
--- a/routers/bots_api.py
+++ b/routers/bots_api.py
@@ -85,6 +85,7 @@ class CreateStreamResponse(BaseModel):
     operation_id=VideoBotsPage.slug_versions[0] + "__stream_create",
     tags=["Copilot Integrations"],
     name="Copilot Integrations Create Stream",
+    openapi_extra={"x-fern-ignore": True},
 )
 def stream_create(request: CreateStreamRequest, response: Response):
     request_id = str(uuid.uuid4())
@@ -173,6 +174,7 @@ class StreamError(BaseModel):
     operation_id=VideoBotsPage.slug_versions[0] + "__stream",
     tags=["Copilot Integrations"],
     name="Copilot integrations Stream Response",
+    openapi_extra={"x-fern-ignore": True},
 )
 def stream_response(request_id: str):
     r = get_redis_cache().getdel(f"gooey/stream-init/v1/{request_id}")
diff --git a/routers/broadcast_api.py b/routers/broadcast_api.py
index cbf1ca07c..82cf8ea59 100644
--- a/routers/broadcast_api.py
+++ b/routers/broadcast_api.py
@@ -51,6 +51,7 @@ class BotBroadcastRequestModel(BaseModel):
     operation_id=VideoBotsPage.slug_versions[0] + "__broadcast",
     tags=["Misc"],
     name=f"Send Broadcast Message",
+    openapi_extra={"x-fern-ignore": True},
 )
 @app.post(
     f"/v2/{VideoBotsPage.slug_versions[0]}/broadcast/send",
diff --git a/scripts/init_self_hosted_pricing.py b/scripts/init_self_hosted_pricing.py
index 68bd63512..4cb5ba224 100644
--- a/scripts/init_self_hosted_pricing.py
+++ b/scripts/init_self_hosted_pricing.py
@@ -2,12 +2,9 @@
 from daras_ai_v2.gpu_server import build_queue_name
 from daras_ai_v2.stable_diffusion import (
-    Text2ImgModels,
-    Img2ImgModels,
+    TextToImageModels,
+    ImageToImageModels,
     InpaintingModels,
-    text2img_model_ids,
-    img2img_model_ids,
-    inpaint_model_ids,
 )
 from recipes.DeforumSD import AnimationModels
 from usage_costs.models import ModelPricing
@@ -17,21 +14,15 @@

 def run():
-    for model in AnimationModels:
-        add_model(model.value, model.name)
-    for model_enum, model_ids in [
-        (Text2ImgModels, text2img_model_ids),
-        (Img2ImgModels, img2img_model_ids),
-        (InpaintingModels, inpaint_model_ids),
+    for model_enum in [
+        AnimationModels,
+        TextToImageModels,
+        ImageToImageModels,
+        InpaintingModels,
     ]:
         for m in model_enum:
-            if "dall_e" in m.name:
-                continue
-            try:
-                add_model(model_ids[m], m.name)
-            except KeyError:
-                pass
-
+            if "dall_e" not in m.name and m.model_id:
+                add_model(m.model_id, m.name)
     add_model("wav2lip_gan.pth", "wav2lip")
     add_model("SadTalker_V0.0.2_512.safetensors", "sadtalker")
diff --git a/scripts/run_all_diffusion.py b/scripts/run_all_diffusion.py
index 183260c5d..c6132b953 100644
--- a/scripts/run_all_diffusion.py
+++ b/scripts/run_all_diffusion.py
@@ -20,9 +20,9 @@
 from daras_ai_v2.stable_diffusion import (
     controlnet,
     ControlNetModels,
-    Img2ImgModels,
+    ImageToImageModels,
     text2img,
-    Text2ImgModels,
+    TextToImageModels,
     img2img,
     instruct_pix2pix,
     sd_upscale,
@@ -34,7 +34,7 @@

 # def fn():
 #     text2img(
-#         selected_model=Img2ImgModels.sd_1_5.name,
+#         selected_model=ImageToImageModels.sd_1_5.name,
 #         prompt=get_random_string(100, string.ascii_letters),
 #         num_outputs=1,
 #         num_inference_steps=1,
@@ -45,7 +45,7 @@
 # #     r = requests.get(GpuEndpoints.sd_multi / "magic")
 # #     raise_for_status(r)
 # #     img2img(
-# #         selected_model=Img2ImgModels.sd_1_5.name,
+# #         selected_model=ImageToImageModels.sd_1_5.name,
 # #         prompt=get_random_string(100, string.ascii_letters),
 # #         num_outputs=1,
 # #         init_image=random_img,
@@ -55,7 +55,7 @@
 # #     )
 # #     controlnet(
 # #         selected_controlnet_model=ControlNetModels.sd_controlnet_depth.name,
-# #         selected_model=Img2ImgModels.sd_1_5.name,
+# #         selected_model=ImageToImageModels.sd_1_5.name,
 # #         prompt=get_random_string(100, string.ascii_letters),
 # #         num_outputs=1,
 # #         init_image=random_img,
@@ -72,11 +72,11 @@
 # exit()

 tasks = []
-for model in Img2ImgModels:
+for model in ImageToImageModels:
     if model in [
-        Img2ImgModels.instruct_pix2pix,
-        Img2ImgModels.dall_e,
-        Img2ImgModels.jack_qiao,
+        ImageToImageModels.instruct_pix2pix,
+        ImageToImageModels.dall_e,
+        ImageToImageModels.jack_qiao,
     ]:
         continue
     print(model)
@@ -96,7 +96,7 @@
     )
     for controlnet_model in ControlNetModels:
         if model in [
-            Img2ImgModels.sd_2,
+            ImageToImageModels.sd_2,
         ]:
             continue
         print(controlnet_model)
@@ -115,10 +115,10 @@
         )
     )

-for model in Text2ImgModels:
+for model in TextToImageModels:
     if model in [
-        Text2ImgModels.dall_e,
-        Text2ImgModels.jack_qiao,
+        TextToImageModels.dall_e,
+        TextToImageModels.jack_qiao,
     ]:
         continue
     print(model)
diff --git a/usage_costs/models.py b/usage_costs/models.py
index 8941d1cd7..2d113e2f9 100644
--- a/usage_costs/models.py
+++ b/usage_costs/models.py
@@ -58,14 +58,14 @@ class ModelProvider(models.IntegerChoices):
 def get_model_choices():
     from daras_ai_v2.language_model import LargeLanguageModels
     from recipes.DeforumSD import AnimationModels
-    from daras_ai_v2.stable_diffusion import Text2ImgModels, Img2ImgModels
+    from daras_ai_v2.stable_diffusion import TextToImageModels, ImageToImageModels

     return (
         [(api.name, api.value) for api in LargeLanguageModels]
         + [(model.name, model.label) for model in AnimationModels]
-        + [(model.name, model.value) for model in Text2ImgModels]
-        + [(model.name, model.value) for model in Img2ImgModels]
-        + [(model.name, model.value) for model in InpaintingModels]
+        + [(model.name, model.label) for model in TextToImageModels]
+        + [(model.name, model.label) for model in ImageToImageModels]
+        + [(model.name, model.label) for model in InpaintingModels]
         + [("wav2lip", "LipSync (wav2lip)")]
         + [("sadtalker", "LipSync (sadtalker)")]
     )
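The `.value` -> `.label` switch here (and in QRCodeGenerator's progress message above) follows from the image-model enums now mixing in a NamedTuple: `.value` is the whole tuple, not a display string. Sketched with a toy stand-in modeled on the `TranslationModel` pattern (the field names and model id are assumptions):

```python
import typing

from daras_ai_v2.custom_enum import GooeyEnum


class ImageModel(typing.NamedTuple):  # hypothetical shape, for illustration
    model_id: str
    label: str


class TextToImageModels(ImageModel, GooeyEnum):  # toy stand-in
    dream_shaper = ImageModel(
        model_id="Lykon/DreamShaper", label="DreamShaper (Lykon)"
    )


m = TextToImageModels.dream_shaper
print(m.value)  # the whole ImageModel tuple -- unusable as a display string
print(m.label)  # 'DreamShaper (Lykon)' -- what get_model_choices() now uses
```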