From 507863471d94cc8e9acf8e0eac2873da2304c8c9 Mon Sep 17 00:00:00 2001 From: Morgan Fainberg Date: Tue, 15 Oct 2024 12:50:58 -0700 Subject: [PATCH] Deduplicate choices Ensure choices are not duplicated, as that makes StrEnum impossible to resolve. Each element should only exist once. The deduplicate must be done at the schema level and therefore a decorator is utilized to mutate duplicated choices passed to Input(). --- python/cog/types.py | 23 +++++++++++++++++++ python/tests/server/fixtures/input_choices.py | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/python/cog/types.py b/python/cog/types.py index f4110e68ca..732bf9d818 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -1,3 +1,4 @@ +import inspect import io import mimetypes import os @@ -53,6 +54,28 @@ class CogBuildConfig(TypedDict, total=False): # pylint: disable=too-many-ancest run: Optional[Union[List[str], List[Dict[str, Any]]]] +# The following decorator is used to mutate the definition of choices in the +# case that the value(s) are duplicated. This results in the inability to +# create the enum as the values are not unique. This is generally a hack to +# work around previously created invalid schemas. +def _deduplicate_choices(func: Any) -> Any: + def wrapper(*args: Any, **kwargs: Any) -> Any: + sig = inspect.signature(func) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + if ( + "choices" in bound_args.arguments + and bound_args.arguments["choices"] is not None + ): + bound_args.arguments["choices"] = list(set(bound_args.arguments["choices"])) + + return func(*bound_args.args, **bound_args.kwargs) + + return wrapper + + +@_deduplicate_choices def Input( # pylint: disable=invalid-name, too-many-arguments default: Any = ..., description: str = None, diff --git a/python/tests/server/fixtures/input_choices.py b/python/tests/server/fixtures/input_choices.py index 659ee20e3f..28f0a8d99f 100644 --- a/python/tests/server/fixtures/input_choices.py +++ b/python/tests/server/fixtures/input_choices.py @@ -2,6 +2,6 @@ class Predictor(BasePredictor): - def predict(self, text: str = Input(choices=["foo", "bar"])) -> str: + def predict(self, text: str = Input(choices=["foo", "bar", "foo"])) -> str: assert type(text) == str return text