Skip to content

Commit

Permalink
Deduplicate choices
Browse files Browse the repository at this point in the history
Ensure choices are not duplicated, as that makes StrEnum impossible to
resolve. Each element should only exist once.

The deduplicate must be done at the schema level and therefore a
decorator is utilized to mutate duplicated choices passed to Input().
  • Loading branch information
tempusfrangit committed Oct 16, 2024
1 parent 37c141e commit 46c61f0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
23 changes: 23 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import io
import mimetypes
import os
Expand Down Expand Up @@ -53,6 +54,28 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancest
run: Optional[Union[List[str], List[Dict[str, Any]]]]


# The following decorator is used to mutate the definition of choices in the
# case that the value(s) are duplicated. This results in the inability to
# create the enum as the values are not unique. This is generally a hack to
# work around previously created invalid schemas.
def _deduplicate_choices(func: Any) -> Any:
def wrapper(*args: Any, **kwargs: Any) -> Any:
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

if (
"choices" in bound_args.arguments
and bound_args.arguments["choices"] is not None
):
bound_args.arguments["choices"] = list(set(bound_args.arguments["choices"]))

return func(*bound_args.args, **bound_args.kwargs)

return wrapper


@_deduplicate_choices
def Input( # pylint: disable=invalid-name, too-many-arguments
default: Any = ...,
description: str = None,
Expand Down
2 changes: 1 addition & 1 deletion python/tests/server/fixtures/input_choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@


class Predictor(BasePredictor):
def predict(self, text: str = Input(choices=["foo", "bar"])) -> str:
def predict(self, text: str = Input(choices=["foo", "bar", "foo"])) -> str:
assert type(text) == str
return text

0 comments on commit 46c61f0

Please sign in to comment.