Skip to content

Commit

Permalink
feat: parameters can be described as Pydantic model for rich schema (#…
Browse files Browse the repository at this point in the history
…6001)

Signed-off-by: Joan Fontanals Martinez <[email protected]>
Co-authored-by: Jina Dev Bot <[email protected]>
  • Loading branch information
JoanFM and jina-bot authored Aug 1, 2023
1 parent 31fd261 commit 9c0378a
Show file tree
Hide file tree
Showing 15 changed files with 508 additions and 119 deletions.
1 change: 1 addition & 0 deletions .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ jobs:
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_v2.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/deployment_http_composite
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py
echo "flag it as jina for codeoverage"
echo "codecov_flag=jina" >> $GITHUB_OUTPUT
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ jobs:
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_v2.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/deployment_http_composite
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py
pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py
echo "flag it as jina for codeoverage"
echo "codecov_flag=jina" >> $GITHUB_OUTPUT
Expand Down
18 changes: 9 additions & 9 deletions jina/clients/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from jina.importer import ImportExtensions

if TYPE_CHECKING: # pragma: no cover

from pydantic import BaseModel
from jina.clients.base import CallbackFnType, InputType
from jina.types.request.data import Response

Expand Down Expand Up @@ -344,7 +344,7 @@ def post(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
parameters: Optional[Dict] = None,
parameters: Union[Dict, 'BaseModel', None] = None,
target_executor: Optional[str] = None,
request_size: int = 100,
show_progress: bool = False,
Expand All @@ -362,12 +362,12 @@ def post(
) -> Optional[Union['DocumentArray', List['Response']]]:
"""Post a general data request to the Flow.
:param inputs: input data which can be an Iterable, a function which returns an Iterable, or a single Document.
:param inputs: input data which can be a DocList, a BaseDoc, an Iterable, a function which returns an Iterable.
:param on: the endpoint which is invoked. All the functions in the executors decorated by `@requests(on=...)` with the same endpoint are invoked.
:param on_done: the function to be called when the :class:`Request` object is resolved.
:param on_error: the function to be called when the :class:`Request` object is rejected.
:param on_always: the function to be called when the :class:`Request` object is either resolved or rejected.
:param parameters: the kwargs that will be sent to the executor
:param parameters: the parameters that will be sent to the executor, this can be a Dict or a Pydantic model
:param target_executor: a regex string. Only matching Executors will process the request.
:param request_size: the number of Documents per request. <=0 means all inputs in one request.
:param show_progress: if set, client will show a progress bar on receiving every request.
Expand All @@ -380,7 +380,7 @@ def post(
:param results_in_order: return the results in the same order as the inputs
:param stream: Applicable only to grpc client. If True, the requests are sent to the target using the gRPC streaming interface otherwise the gRPC unary interface will be used. The value is True by default.
:param prefetch: How many Requests are processed from the Client at the same time. If not provided then Gateway prefetch value will be used.
:param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`.
:param return_type: the DocList or BaseDoc type to be returned. By default, it is `DocumentArray`.
:param kwargs: additional parameters
:return: None or DocumentArray containing all response Documents
Expand Down Expand Up @@ -458,7 +458,7 @@ async def post(
on_done: Optional['CallbackFnType'] = None,
on_error: Optional['CallbackFnType'] = None,
on_always: Optional['CallbackFnType'] = None,
parameters: Optional[Dict] = None,
parameters: Union[Dict, 'BaseModel', None] = None,
target_executor: Optional[str] = None,
request_size: int = 100,
show_progress: bool = False,
Expand All @@ -476,12 +476,12 @@ async def post(
) -> AsyncGenerator[None, Union['DocumentArray', 'Response']]:
"""Async Post a general data request to the Flow.
:param inputs: input data which can be an Iterable, a function which returns an Iterable, or a single Document.
:param inputs: input data which can be a DocList, a BaseDoc, an Iterable, a function which returns an Iterable.
:param on: the endpoint which is invoked. All the functions in the executors decorated by `@requests(on=...)` with the same endpoint are invoked.
:param on_done: the function to be called when the :class:`Request` object is resolved.
:param on_error: the function to be called when the :class:`Request` object is rejected.
:param on_always: the function to be called when the :class:`Request` object is either resolved or rejected.
:param parameters: the kwargs that will be sent to the executor
:param parameters: the parameters that will be sent to the executor, this can be a Dict or a Pydantic model
:param target_executor: a regex string. Only matching Executors will process the request.
:param request_size: the number of Documents per request. <=0 means all inputs in one request.
:param show_progress: if set, client will show a progress bar on receiving every request.
Expand All @@ -494,7 +494,7 @@ async def post(
:param results_in_order: return the results in the same order as the inputs
:param stream: Applicable only to grpc client. If True, the requests are sent to the target using the gRPC streaming interface otherwise the gRPC unary interface will be used. The value is True by default.
:param prefetch: How many Requests are processed from the Client at the same time. If not provided then Gateway prefetch value will be used.
:param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`.
:param return_type: the DocList or BaseDoc type to be returned. By default, it is `DocumentArray`.
:param kwargs: additional parameters, can be used to pass metadata or authentication information in the server call
:yield: Response object
Expand Down
4 changes: 3 additions & 1 deletion jina/clients/request/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,19 @@
from jina._docarray.document import DocumentSourceType
from jina._docarray.document.mixins.content import DocumentContentType
from jina.types.request import Request
from docarray import DocList, BaseDoc

SingletonDataType = Union[
DocumentContentType,
DocumentSourceType,
Document,
BaseDoc,
Tuple[DocumentContentType, DocumentContentType],
Tuple[DocumentSourceType, DocumentSourceType],
]

GeneratorSourceType = Union[
Document, Iterable[SingletonDataType], AsyncIterable[SingletonDataType]
Document, Iterable[SingletonDataType], AsyncIterable[SingletonDataType], DocList
]


Expand Down
176 changes: 137 additions & 39 deletions jina/serve/executors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
ArgNamespace,
T,
get_or_reuse_loop,
is_generator,
iscoroutinefunction,
typename,
)
Expand Down Expand Up @@ -63,6 +62,56 @@
__all__ = ['BaseExecutor', __dry_run_endpoint__]


def is_pydantic_model(annotation: Type) -> bool:
    """Detect whether a parameter annotation corresponds to a Pydantic model.

    Unwraps generic aliases (``Optional``, ``Union``, containers, …) and
    reports whether a Pydantic ``BaseModel`` subclass appears anywhere
    inside the annotation.

    :param annotation: The annotation from which to extract the Pydantic model.
    :return: boolean indicating if a Pydantic model is inside the annotation
    """
    from typing import get_args, get_origin

    from pydantic import BaseModel

    # For a plain (non-generic) class get_origin() is None, so fall back
    # to the annotation itself.
    origin = get_origin(annotation) or annotation
    args = get_args(annotation)

    # If the origin itself is a Pydantic model, return True
    if isinstance(origin, type) and issubclass(origin, BaseModel):
        return True

    # Check the arguments (the actual types inside Union, Optional, etc.)
    if args:
        return any(is_pydantic_model(arg) for arg in args)

    return False


def get_inner_pydantic_model(annotation: Type) -> Optional[Type]:
    """Extract the Pydantic model class from an annotation, unwrapping
    ``Optional``/``Union``/generic wrappers when present.

    :param annotation: The annotation from which to extract the Pydantic model.
    :return: the inner Pydantic ``BaseModel`` subclass, or ``None`` if the
        annotation does not contain one (or Pydantic cannot be imported)
    """
    try:
        from typing import get_args, get_origin

        from pydantic import BaseModel

        # For a plain (non-generic) class get_origin() is None, so fall back
        # to the annotation itself.
        origin = get_origin(annotation) or annotation
        args = get_args(annotation)

        # If the origin itself is a Pydantic model, return it directly
        if isinstance(origin, type) and issubclass(origin, BaseModel):
            return origin

        # Check the arguments (the actual types inside Union, Optional, etc.)
        for arg in args:
            if is_pydantic_model(arg):
                return arg
    except Exception:
        # Best-effort: any failure (e.g. pydantic unavailable, exotic
        # annotation object) means "no model found" rather than a crash.
        pass
    return None


class ExecutorType(type(JAMLCompatible), type):
"""The class of Executor type, which is the metaclass of :class:`BaseExecutor`."""

Expand Down Expand Up @@ -114,9 +163,11 @@ def register_class(cls):

class _FunctionWithSchema(NamedTuple):
fn: Callable
is_generator: False
is_batch_docs: False
is_generator: bool
is_batch_docs: bool
is_singleton_doc: False
parameters_is_pydantic_model: bool
parameters_model: Type
request_schema: Type[DocumentArray] = DocumentArray
response_schema: Type[DocumentArray] = DocumentArray

Expand Down Expand Up @@ -171,7 +222,6 @@ def validate(self):

@staticmethod
def get_function_with_schema(fn: Callable) -> T:

# if it's not a generator function, infer the type annotation from the docs parameter
# otherwise, infer from the doc parameter (since generator endpoints expect only 1 document as input)
is_generator = getattr(fn, '__is_generator__', False)
Expand All @@ -188,6 +238,16 @@ def get_function_with_schema(fn: Callable) -> T:
docs_annotation = fn.__annotations__.get(
'docs', fn.__annotations__.get('doc', None)
)
parameters_model = (
fn.__annotations__.get('parameters', None) if docarray_v2 else None
)
parameters_is_pydantic_model = False
if parameters_model is not None and docarray_v2:
from pydantic import BaseModel

parameters_is_pydantic_model = is_pydantic_model(parameters_model)
parameters_model = get_inner_pydantic_model(parameters_model)

if docarray_v2:
from docarray import BaseDoc, DocList

Expand Down Expand Up @@ -256,6 +316,8 @@ def get_function_with_schema(fn: Callable) -> T:
is_generator=is_generator,
is_singleton_doc=is_singleton_doc,
is_batch_docs=is_batch_docs,
parameters_model=parameters_model,
parameters_is_pydantic_model=parameters_is_pydantic_model,
request_schema=request_schema,
response_schema=response_schema,
)
Expand Down Expand Up @@ -370,6 +432,7 @@ def _get_endpoint_models_dict(self):
_is_generator = function_with_schema.is_generator
_is_singleton_doc = function_with_schema.is_singleton_doc
_is_batch_docs = function_with_schema.is_batch_docs
_parameters_model = function_with_schema.parameters_model
if docarray_v2:
# if the endpoint is not a generator endpoint, then the request schema is a DocumentArray and we need
# to get the doc_type from the schema
Expand Down Expand Up @@ -403,6 +466,12 @@ def _get_endpoint_models_dict(self):
},
'is_generator': _is_generator,
'is_singleton_doc': _is_singleton_doc,
'parameters': {
'name': _parameters_model.__name__
if _parameters_model is not None
else None,
'model': _parameters_model,
},
}
return endpoint_models

Expand Down Expand Up @@ -626,52 +695,81 @@ async def __acall__(self, req_endpoint: str, **kwargs):
async def __acall_endpoint__(
self, req_endpoint, tracing_context: Optional['Context'], **kwargs
):

# Decorator to make sure that `parameters` are passed as PydanticModels if needed
def parameters_as_pydantic_models_decorator(func, parameters_pydantic_model):
@functools.wraps(func) # Step 2: Use functools.wraps to preserve metadata
def wrapper(*args, **kwargs):
parameters = kwargs.get('parameters', None)
if parameters is not None:
parameters = parameters_pydantic_model(**parameters)
kwargs['parameters'] = parameters
result = func(*args, **kwargs)
return result

return wrapper

# Decorator to make sure that `docs` are fed one by one to method using singleton document serving
def loop_docs_decorator(func):
@functools.wraps(func) # Step 2: Use functools.wraps to preserve metadata
def wrapper(*args, **kwargs):
docs = kwargs.pop('docs')
if docarray_v2:
from docarray import DocList

ret = DocList[response_schema]()
else:
ret = DocumentArray()
for doc in docs:
f_ret = func(*args, doc=doc, **kwargs)
if f_ret is None:
ret.append(doc) # this means change in place
else:
ret.append(f_ret)
return ret

return wrapper

def async_loop_docs_decorator(func):
@functools.wraps(func) # Step 2: Use functools.wraps to preserve metadata
async def wrapper(*args, **kwargs):
docs = kwargs.pop('docs')
if docarray_v2:
from docarray import DocList

ret = DocList[response_schema]()
else:
ret = DocumentArray()
for doc in docs:
f_ret = await original_func(*args, doc=doc, **kwargs)
if f_ret is None:
ret.append(doc) # this means change in place
else:
ret.append(f_ret)
return ret

return wrapper

fn_info = self.requests[req_endpoint]
original_func = fn_info.fn
is_generator = fn_info.is_generator
is_batch_docs = fn_info.is_batch_docs
response_schema = fn_info.response_schema
parameters_model = fn_info.parameters_model
is_parameters_pydantic_model = fn_info.parameters_is_pydantic_model

func = original_func
if is_generator or is_batch_docs:
func = original_func
pass
elif kwargs.get('docs', None) is not None:
# This means I need to pass every doc (most likely 1, but potentially more)
if iscoroutinefunction(original_func):

async def loop_func(*args, **kwargs):
docs = kwargs.pop('docs')
if docarray_v2:
from docarray import DocList

ret = DocList[response_schema]()
else:
ret = DocumentArray()
for doc in docs:
f_ret = await original_func(*args, doc=doc, **kwargs)
if f_ret is None:
ret.append(doc) # this means change in place
else:
ret.append(f_ret)
return ret

func = async_loop_docs_decorator(original_func)
else:
func = loop_docs_decorator(original_func)

def loop_func(*args, **kwargs):
docs = kwargs.pop('docs')
if docarray_v2:
from docarray import DocList

ret = DocList[response_schema]()
else:
ret = DocumentArray()
for doc in docs:
f_ret = original_func(*args, doc=doc, **kwargs)
if f_ret is None:
ret.append(doc) # this means change in place
else:
ret.append(f_ret)
return ret

func = loop_func
if is_parameters_pydantic_model:
func = parameters_as_pydantic_models_decorator(func, parameters_model)

async def exec_func(
summary, histogram, histogram_metric_labels, tracing_context
Expand Down
9 changes: 8 additions & 1 deletion jina/serve/executors/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,14 @@ def _inject_owner_attrs(
)

fn_with_schema = _FunctionWithSchema(
fn=fn_with_schema.fn, is_generator=fn_with_schema.is_generator, is_singleton_doc=fn_with_schema.is_singleton_doc, is_batch_docs=fn_with_schema.is_batch_docs, request_schema=request_schema_arg, response_schema=response_schema_arg
fn=fn_with_schema.fn,
is_generator=fn_with_schema.is_generator,
is_singleton_doc=fn_with_schema.is_singleton_doc,
is_batch_docs=fn_with_schema.is_batch_docs,
parameters_is_pydantic_model=fn_with_schema.parameters_is_pydantic_model,
parameters_model=fn_with_schema.parameters_model,
request_schema=request_schema_arg,
response_schema=response_schema_arg
)
fn_with_schema.validate()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def _handle_request(request: 'Request') -> 'Tuple[Future, Optional[Future]]':
from docarray.base_doc import AnyDoc

prev_doc_array_cls = request.data.document_array_cls
print(f' hey here I am JOAN')
request.data.document_array_cls = DocList[AnyDoc]
request_doc_ids = request.data.docs.id
request.data._loaded_doc_array = None
Expand Down
Loading

0 comments on commit 9c0378a

Please sign in to comment.