Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAPI revisions for compatibility with Fern SDK generator #422

Open
wants to merge 43 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
6442d2d
Refactor auth code to output auth scheme in OpenAPI spec
nikochiko Jul 29, 2024
dc46d6c
Merge branch 'master' into bearer-openapi
nikochiko Jul 29, 2024
1b5f46f
Add openapi params for fern bearer auth, hide healthcheck from fern
nikochiko Aug 1, 2024
63deccf
Add x-fern-sdk-return-value for all status routes
nikochiko Aug 1, 2024
53a8036
fern: ignore v2 sync APIs
nikochiko Aug 1, 2024
e318e0b
add query param for example_id to v2 sync and v3 async
nikochiko Aug 1, 2024
88dd5b5
Add method-name and group-name to openapi schema
nikochiko Aug 2, 2024
6ac8648
Merge branch 'master' into openapi-fern-revision
nikochiko Sep 4, 2024
4aa6f68
Add SDK names and OpenAPI extras for endpoints
nikochiko Sep 12, 2024
14344aa
Fix get_openapi_extra method on BasePage to use sdk_method_name
nikochiko Sep 12, 2024
0c715a4
fix method name for get_balance endpoint
nikochiko Sep 12, 2024
c0ac5d9
fix: ignore broadcast and bot APIs for SDK generation
nikochiko Sep 12, 2024
9d99420
use GooeyEnum for LargeLanguageModels with api_enum() method to gener…
nikochiko Sep 12, 2024
e1edb2c
fix sdk method name for copilot endpoint
nikochiko Sep 12, 2024
325cfd2
Revert "use GooeyEnum for LargeLanguageModels with api_enum() method …
nikochiko Sep 12, 2024
48abd59
use api_enum method on GooeyEnum class, use it with LargeLanguageModels
nikochiko Sep 12, 2024
90fad12
use .api_enum in ResponseModel for CompareLLM
nikochiko Sep 12, 2024
1f5150e
Use .api_enum with LipsyncModel and rename LipsyncModel -> LipsyncModels
nikochiko Sep 12, 2024
146c771
use GooeyEnum for AnimationModels
nikochiko Sep 12, 2024
e8a342b
fix: AsrModels to use GooeyEnum
nikochiko Sep 12, 2024
cbb3a09
Use GooeyEnum for EmbeddingModels
nikochiko Sep 13, 2024
76fe9aa
Refactor GooeyEnum to separate .name and .api_value
nikochiko Sep 13, 2024
aacf0b8
Use GooeyEnum for SerpSearchType & SerpSearchLocation
nikochiko Sep 13, 2024
3f8d305
Add title for QRCode VCard field to VCard
nikochiko Sep 13, 2024
83cacc6
fix serp types for SEOSummary recipe
nikochiko Sep 13, 2024
bee12c4
Use Enum for ResponseFormatType
nikochiko Sep 13, 2024
c8f296d
Use GooeyEnum for ControlNetModels
nikochiko Sep 13, 2024
8509c61
Use GooeyEnum for Text2ImgModels, Img2ImgModels, UpscalerModels
nikochiko Sep 13, 2024
4f60610
use GooeyEnum for translation models
nikochiko Sep 13, 2024
1f70c0a
Use GooeyEnum for segmentation models
nikochiko Sep 13, 2024
00b6bfc
Use GooeyEnum for AsrOutputFormat
nikochiko Sep 13, 2024
d38fc4f
Use GooeyEnum for CitationStyles
nikochiko Sep 13, 2024
dbe84f3
Use GooeyEnum for Scheduler and TextToSpeechProviders
nikochiko Sep 13, 2024
3192e76
Use GooeyEnum for text2audio models and combine document chain type
nikochiko Sep 13, 2024
25fe3f2
Rename Text2Img->TextToImage and Img2Img->ImageToImage for parity wit…
nikochiko Sep 13, 2024
4c46fe6
Rename for parity with SDK
nikochiko Sep 13, 2024
55e5ea9
x-fern-type-name for VCard field in QRCode
nikochiko Sep 13, 2024
e48b197
more renaming and remove api_choices in favor of api_enum
nikochiko Sep 13, 2024
a38ad6a
Merge branch 'master' into openapi-fern-revision
nikochiko Sep 13, 2024
25abd67
Fix defaults for serp_search_type and serp_search_location
nikochiko Sep 13, 2024
d75faec
remove debug print statements
nikochiko Sep 13, 2024
0713227
Revert "Add openapi params for fern bearer auth, hide healthcheck fro…
nikochiko Sep 16, 2024
9a5f74b
Revert "Refactor auth code to output auth scheme in OpenAPI spec"
nikochiko Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os.path
import tempfile
import typing
from enum import Enum

import gooey_gui as gui
import requests
Expand All @@ -14,6 +13,7 @@
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2 import settings
from daras_ai_v2.azure_asr import azure_asr
from daras_ai_v2.custom_enum import GooeyEnum
from daras_ai_v2.exceptions import (
raise_for_status,
UserError,
Expand Down Expand Up @@ -204,7 +204,7 @@
GHANA_NLP_MAXLEN = 500


class AsrModels(Enum):
class AsrModels(GooeyEnum):
whisper_large_v2 = "Whisper Large v2 (openai)"
whisper_large_v3 = "Whisper Large v3 (openai)"
whisper_hindi_large_v2 = "Whisper Hindi Large v2 (Bhashini)"
Expand Down Expand Up @@ -277,7 +277,7 @@ class AsrOutputJson(typing_extensions.TypedDict):
chunks: typing_extensions.NotRequired[list[AsrChunk]]


class AsrOutputFormat(Enum):
class AsrOutputFormat(GooeyEnum):
text = "Text"
json = "JSON"
srt = "SRT"
Expand All @@ -290,7 +290,7 @@ class TranslationModel(typing.NamedTuple):
supports_auto_detect: bool = False


class TranslationModels(TranslationModel, Enum):
class TranslationModels(TranslationModel, GooeyEnum):
google = TranslationModel(
label="Google Translate",
supports_glossary=True,
Expand Down
10 changes: 10 additions & 0 deletions daras_ai_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class BasePage:
title: str
workflow: Workflow
slug_versions: list[str]
sdk_method_name: str

sane_defaults: dict = {}

Expand All @@ -144,6 +145,9 @@ class BasePage:
)

class RequestModel(BaseModel):
class Config:
use_enum_values = True

functions: list[RecipeFunction] | None = Field(
title="🧩 Developer Tools and Functions",
)
Expand Down Expand Up @@ -317,6 +321,12 @@ def sentry_event_set_user(self, event, hint):
}
return event

@classmethod
def get_openapi_extra(cls) -> dict[str, typing.Any]:
return {
"x-fern-sdk-method-name": cls.sdk_method_name,
}

def refresh_state(self):
sr = self.current_sr
channel = self.realtime_channel_name(sr.run_id, sr.uid)
Expand Down
47 changes: 42 additions & 5 deletions daras_ai_v2/custom_enum.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
import functools
import typing
from enum import Enum

import typing_extensions


def cached_classmethod(func: typing.Callable):
"""
This cache is a hack to get around a bug where
dynamic Enums with the same name cause a crash
when generating the OpenAPI spec.
"""

@functools.wraps(func)
def wrapper(cls):
if not hasattr(cls, "_cached_classmethod"):
cls._cached_classmethod = {}
if id(func) not in cls._cached_classmethod:
cls._cached_classmethod[id(func)] = func(cls)

return cls._cached_classmethod[id(func)]

return wrapper


T = typing.TypeVar("T", bound="GooeyEnum")


class GooeyEnum(Enum):
@property
def api_value(self):
# api_value is usually the name
return self.name

@classmethod
def db_choices(cls):
return [(e.db_value, e.label) for e in cls]
Expand All @@ -20,12 +46,23 @@ def from_db(cls, db_value) -> typing_extensions.Self:

@classmethod
@property
def api_choices(cls):
return typing.Literal[tuple(e.name for e in cls)]
@cached_classmethod
def api_enum(cls):
"""
Enum that is useful as a type in API requests.

Maps `api_value`->`api_value` (default: `name`->`name`)
for all values.

The title (same as the Enum class name) will be
used as the new Enum's title. This will be passed
on to the OpenAPI schema and the generated SDK.
"""
return Enum(cls.__name__, {e.api_value: e.api_value for e in cls})

@classmethod
def from_api(cls, name: str) -> typing_extensions.Self:
def from_api(cls, api_value: str) -> typing_extensions.Self:
for e in cls:
if e.name == name:
if e.api_value == api_value:
return e
raise ValueError(f"Invalid {cls.__name__} {name=}")
raise ValueError(f"Invalid {cls.__name__} {api_value=}")
12 changes: 4 additions & 8 deletions daras_ai_v2/embedding_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import hashlib
import io
import typing
from enum import Enum
from functools import partial

import numpy as np
Expand All @@ -13,6 +12,7 @@
from jinja2.lexer import whitespace_re
from loguru import logger

from daras_ai_v2.custom_enum import GooeyEnum
from daras_ai_v2.gpu_server import call_celery_task
from daras_ai_v2.language_model import get_openai_client
from daras_ai_v2.redis_cache import (
Expand All @@ -25,7 +25,7 @@ class EmbeddingModel(typing.NamedTuple):
label: str


class EmbeddingModels(Enum):
class EmbeddingModels(EmbeddingModel, GooeyEnum):
openai_3_large = EmbeddingModel(
model_id=("openai-text-embedding-3-large-prod-ca-1", "text-embedding-3-large"),
label="Text Embedding 3 Large (OpenAI)",
Expand Down Expand Up @@ -65,12 +65,8 @@ class EmbeddingModels(Enum):
)

@property
def model_id(self) -> typing.Iterable[str] | str:
return self.value.model_id

@property
def label(self) -> str:
return self.value.label
def db_value(self):
return self.name

@classmethod
def get(cls, key, default=None):
Expand Down
5 changes: 2 additions & 3 deletions daras_ai_v2/image_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from enum import Enum

import requests

from daras_ai_v2.custom_enum import GooeyEnum
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.gpu_server import (
call_celery_task_outfile,
)


class ImageSegmentationModels(Enum):
class ImageSegmentationModels(str, GooeyEnum):
dis = "Dichotomous Image Segmentation"
u2net = "U²-Net"

Expand Down
53 changes: 26 additions & 27 deletions daras_ai_v2/img_model_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from daras_ai_v2.enum_selector_widget import enum_selector, enum_multiselect
from daras_ai_v2.stable_diffusion import (
Text2ImgModels,
TextToImageModels,
InpaintingModels,
Img2ImgModels,
ImageToImageModels,
ControlNetModels,
controlnet_model_explanations,
Schedulers,
)

Expand Down Expand Up @@ -37,10 +36,10 @@ def img_model_settings(
negative_prompt_setting(selected_model)

num_outputs_setting(selected_model)
if models_enum is not Img2ImgModels:
if models_enum is not ImageToImageModels:
output_resolution_setting()

if models_enum is Text2ImgModels:
if models_enum is TextToImageModels:
sd_2_upscaling_setting()

col1, col2 = gui.columns(2)
Expand All @@ -49,11 +48,11 @@ def img_model_settings(
guidance_scale_setting(selected_model)

with col2:
if models_enum is Img2ImgModels and not gui.session_state.get(
if models_enum is ImageToImageModels and not gui.session_state.get(
"selected_controlnet_model"
):
prompt_strength_setting(selected_model)
if selected_model == Img2ImgModels.instruct_pix2pix.name:
if selected_model == ImageToImageModels.instruct_pix2pix.name:
instruct_pix2pix_settings()

if show_scheduler:
Expand All @@ -73,10 +72,10 @@ def model_selector(
high_explanation: str = "At {high} the control nets will be applied tightly to the prompted visual, possibly overriding the prompt",
):
controlnet_unsupported_models = [
Img2ImgModels.instruct_pix2pix.name,
Img2ImgModels.dall_e.name,
Img2ImgModels.jack_qiao.name,
Img2ImgModels.sd_2.name,
ImageToImageModels.instruct_pix2pix.name,
ImageToImageModels.dall_e.name,
ImageToImageModels.jack_qiao.name,
ImageToImageModels.sd_2.name,
]
col1, col2 = gui.columns(2)
with col1:
Expand All @@ -96,12 +95,12 @@ def model_selector(
"""
)
if (
models_enum is Img2ImgModels
models_enum is ImageToImageModels
and gui.session_state.get("selected_model") in controlnet_unsupported_models
):
if "selected_controlnet_model" in gui.session_state:
gui.session_state["selected_controlnet_model"] = None
elif models_enum is Img2ImgModels:
elif models_enum is ImageToImageModels:
enum_multiselect(
ControlNetModels,
label=controlnet_explanation,
Expand Down Expand Up @@ -130,9 +129,7 @@ def controlnet_settings(
if not models:
return

if extra_explanations is None:
extra_explanations = {}
explanations = controlnet_model_explanations | extra_explanations
extra_explanations = extra_explanations or {}

state_values = gui.session_state.get("controlnet_conditioning_scale", [])
new_values = []
Expand All @@ -157,7 +154,9 @@ def controlnet_settings(
pass
new_values.append(
controlnet_weight_setting(
selected_controlnet_model=model, explanations=explanations, key=key
selected_controlnet_model=model,
extra_explanations=extra_explanations,
key=key,
),
)
gui.session_state["controlnet_conditioning_scale"] = new_values
Expand All @@ -166,13 +165,13 @@ def controlnet_settings(
def controlnet_weight_setting(
*,
selected_controlnet_model: str,
explanations: dict[ControlNetModels, str],
extra_explanations: dict[ControlNetModels, str],
key: str = "controlnet_conditioning_scale",
):
model = ControlNetModels[selected_controlnet_model]
return gui.slider(
label=f"""
{explanations[model]}.
{extra_explanations.get(model, model.explanation)}.
""",
key=key,
min_value=CONTROLNET_CONDITIONING_SCALE_RANGE[0],
Expand Down Expand Up @@ -215,7 +214,7 @@ def quality_setting(selected_models=None):
return
if any(
[
selected_model in [Text2ImgModels.dall_e_3.name]
selected_model in [TextToImageModels.dall_e_3.name]
for selected_model in selected_models
]
):
Expand Down Expand Up @@ -375,8 +374,8 @@ def sd_2_upscaling_setting():

def scheduler_setting(selected_model: str = None):
if selected_model in [
Text2ImgModels.dall_e.name,
Text2ImgModels.jack_qiao,
TextToImageModels.dall_e.name,
TextToImageModels.jack_qiao,
]:
return
enum_selector(
Expand All @@ -395,8 +394,8 @@ def scheduler_setting(selected_model: str = None):

def guidance_scale_setting(selected_model: str = None):
if selected_model in [
Text2ImgModels.dall_e.name,
Text2ImgModels.jack_qiao,
TextToImageModels.dall_e.name,
TextToImageModels.jack_qiao,
]:
return
gui.slider(
Expand Down Expand Up @@ -435,8 +434,8 @@ def instruct_pix2pix_settings():

def prompt_strength_setting(selected_model: str = None):
if selected_model in [
Img2ImgModels.dall_e.name,
Img2ImgModels.instruct_pix2pix.name,
ImageToImageModels.dall_e.name,
ImageToImageModels.instruct_pix2pix.name,
]:
return

Expand All @@ -458,7 +457,7 @@ def prompt_strength_setting(selected_model: str = None):


def negative_prompt_setting(selected_model: str = None):
if selected_model in [Text2ImgModels.dall_e.name]:
if selected_model in [TextToImageModels.dall_e.name]:
return

gui.text_area(
Expand Down
7 changes: 5 additions & 2 deletions daras_ai_v2/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from daras_ai.image_input import gs_url_to_uri, bytes_to_cv2_img, cv2_img_to_bytes
from daras_ai_v2.asr import get_google_auth_session
from daras_ai_v2.custom_enum import GooeyEnum
from daras_ai_v2.exceptions import raise_for_status, UserError
from daras_ai_v2.gpu_server import call_celery_task
from daras_ai_v2.text_splitter import (
Expand Down Expand Up @@ -72,7 +73,7 @@ class LLMSpec(typing.NamedTuple):
supports_json: bool = False


class LargeLanguageModels(Enum):
class LargeLanguageModels(GooeyEnum):
# https://platform.openai.com/docs/models/gpt-4o
gpt_4_o = LLMSpec(
label="GPT-4o (openai)",
Expand Down Expand Up @@ -474,7 +475,9 @@ def get_entry_text(entry: ConversationEntry) -> str:
)


ResponseFormatType = typing.Literal["text", "json_object"]
class ResponseFormatType(str, GooeyEnum):
text = "text"
json_object = "json_object"


def run_language_model(
Expand Down
Loading
Loading