diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index e3c7a7f892c88..0901924a5fa0b 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -144,6 +144,7 @@ jobs: pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_v2.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/deployment_http_composite pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py + pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py echo "flag it as jina for codeoverage" echo "codecov_flag=jina" >> $GITHUB_OUTPUT diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 94812f311af48..93db99fad45f0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -469,6 +469,7 @@ jobs: pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_v2.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/deployment_http_composite pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py + pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py echo "flag it as jina for codeoverage" echo "codecov_flag=jina" >> $GITHUB_OUTPUT diff --git a/jina/clients/mixin.py b/jina/clients/mixin.py index 3e49d99b56433..32bdc093f4ebd 100644 --- a/jina/clients/mixin.py +++ b/jina/clients/mixin.py @@ -12,7 +12,7 @@ from jina.importer import ImportExtensions if TYPE_CHECKING: # pragma: no cover - + from pydantic import BaseModel from jina.clients.base import CallbackFnType, InputType from jina.types.request.data import Response @@ -344,7 +344,7 @@ def post( on_done: Optional['CallbackFnType'] = None, on_error: Optional['CallbackFnType'] = None, on_always: Optional['CallbackFnType'] = None, - parameters: Optional[Dict] = None, + parameters: Union[Dict, 'BaseModel', None] = None, target_executor: Optional[str] = None, request_size: int = 100, 
        show_progress: bool = False,
@@ -362,12 +362,12 @@ def post(
     ) -> Optional[Union['DocumentArray', List['Response']]]:
         """Post a general data request to the Flow.
 
-        :param inputs: input data which can be an Iterable, a function which returns an Iterable, or a single Document.
+        :param inputs: input data which can be a DocList, a BaseDoc, an Iterable, or a function which returns an Iterable.
         :param on: the endpoint which is invoked. All the functions in the executors decorated by `@requests(on=...)` with the same endpoint are invoked.
         :param on_done: the function to be called when the :class:`Request` object is resolved.
         :param on_error: the function to be called when the :class:`Request` object is rejected.
         :param on_always: the function to be called when the :class:`Request` object is either resolved or rejected.
-        :param parameters: the kwargs that will be sent to the executor
+        :param parameters: the parameters that will be sent to the executor; this can be a Dict or a Pydantic model
         :param target_executor: a regex string. Only matching Executors will process the request.
         :param request_size: the number of Documents per request. <=0 means all inputs in one request.
         :param show_progress: if set, client will show a progress bar on receiving every request.
@@ -380,7 +380,7 @@ def post(
         :param results_in_order: return the results in the same order as the inputs
         :param stream: Applicable only to grpc client. If True, the requests are sent to the target using the gRPC streaming interface otherwise the gRPC unary interface will be used. The value is True by default.
         :param prefetch: How many Requests are processed from the Client at the same time. If not provided then Gateway prefetch value will be used.
-        :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`.
+        :param return_type: the DocList or BaseDoc type to be returned. By default, it is `DocumentArray`.
         :param kwargs: additional parameters
 
         :return: None or DocumentArray containing all response Documents
@@ -458,7 +458,7 @@ async def post(
         on_done: Optional['CallbackFnType'] = None,
         on_error: Optional['CallbackFnType'] = None,
         on_always: Optional['CallbackFnType'] = None,
-        parameters: Optional[Dict] = None,
+        parameters: Union[Dict, 'BaseModel', None] = None,
         target_executor: Optional[str] = None,
         request_size: int = 100,
         show_progress: bool = False,
@@ -476,12 +476,12 @@ async def post(
     ) -> AsyncGenerator[None, Union['DocumentArray', 'Response']]:
         """Async Post a general data request to the Flow.
 
-        :param inputs: input data which can be an Iterable, a function which returns an Iterable, or a single Document.
+        :param inputs: input data which can be a DocList, a BaseDoc, an Iterable, or a function which returns an Iterable.
         :param on: the endpoint which is invoked. All the functions in the executors decorated by `@requests(on=...)` with the same endpoint are invoked.
         :param on_done: the function to be called when the :class:`Request` object is resolved.
         :param on_error: the function to be called when the :class:`Request` object is rejected.
        :param on_always: the function to be called when the :class:`Request` object is either resolved or rejected.
-        :param parameters: the kwargs that will be sent to the executor
+        :param parameters: the parameters that will be sent to the executor; this can be a Dict or a Pydantic model
         :param target_executor: a regex string. Only matching Executors will process the request.
         :param request_size: the number of Documents per request. <=0 means all inputs in one request.
         :param show_progress: if set, client will show a progress bar on receiving every request.
@@ -494,7 +494,7 @@ async def post(
         :param results_in_order: return the results in the same order as the inputs
         :param stream: Applicable only to grpc client. If True, the requests are sent to the target using the gRPC streaming interface otherwise the gRPC unary interface will be used. The value is True by default.
         :param prefetch: How many Requests are processed from the Client at the same time. If not provided then Gateway prefetch value will be used.
-        :param return_type: the DocumentArray type to be returned. By default, it is `DocumentArray`.
+        :param return_type: the DocList or BaseDoc type to be returned. By default, it is `DocumentArray`.
         :param kwargs: additional parameters, can be used to pass metadata or authentication information in the server call
 
         :yield: Response object
diff --git a/jina/clients/request/__init__.py b/jina/clients/request/__init__.py
index a82f1283c4474..d89960d0e19ca 100644
--- a/jina/clients/request/__init__.py
+++ b/jina/clients/request/__init__.py
@@ -22,17 +22,19 @@
     from jina._docarray.document import DocumentSourceType
     from jina._docarray.document.mixins.content import DocumentContentType
     from jina.types.request import Request
+    from docarray import DocList, BaseDoc
 
 SingletonDataType = Union[
     DocumentContentType,
     DocumentSourceType,
     Document,
+    BaseDoc,
     Tuple[DocumentContentType, DocumentContentType],
     Tuple[DocumentSourceType, DocumentSourceType],
 ]
 
 GeneratorSourceType = Union[
-    Document, Iterable[SingletonDataType], AsyncIterable[SingletonDataType]
+    Document, Iterable[SingletonDataType], AsyncIterable[SingletonDataType], DocList
 ]
diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py
index a2bbe0788e59d..2f4117c9e772d 100644
--- a/jina/serve/executors/__init__.py
+++ b/jina/serve/executors/__init__.py
@@ -33,7 +33,6 @@
     ArgNamespace,
     T,
     get_or_reuse_loop,
-    is_generator,
     iscoroutinefunction,
     typename,
 )
@@ -63,6 +62,56 @@
 __all__ = ['BaseExecutor', __dry_run_endpoint__]
 
 
+def is_pydantic_model(annotation: Type) -> bool:
+    """Detect whether a parameter annotation corresponds to a Pydantic model
+
+    :param annotation: The annotation from which to extract the Pydantic model.
+    :return: boolean indicating if a Pydantic model is inside the annotation
+    """
+    from pydantic import BaseModel
+    from typing import get_args, get_origin
+
+    origin = get_origin(annotation) or annotation
+    args = get_args(annotation)
+
+    # If the origin itself is a Pydantic model, return True
+    if isinstance(origin, type) and issubclass(origin, BaseModel):
+        return True
+
+    # Check the arguments (for the actual types inside Union, Optional, etc.)
+    if args:
+        return any(is_pydantic_model(arg) for arg in args)
+
+    return False
+
+
+def get_inner_pydantic_model(annotation: Type) -> Optional[Type]:
+    """Extract the inner Pydantic model from an annotation, unwrapping containers such as Optional or Union
+
+    :param annotation: The annotation from which to extract the Pydantic model.
+    :return: The inner Pydantic model, or None if the annotation does not contain one
+    """
+    try:
+        from pydantic import BaseModel
+        from typing import get_args, get_origin
+
+        origin = get_origin(annotation) or annotation
+        args = get_args(annotation)
+
+        # If the origin itself is a Pydantic model, return it
+        if isinstance(origin, type) and issubclass(origin, BaseModel):
+            return origin
+
+        # Check the arguments (for the actual types inside Union, Optional, etc.)
+        if args:
+            for arg in args:
+                if is_pydantic_model(arg):
+                    return arg
+    except Exception:
+        pass
+    return None
+
+
 class ExecutorType(type(JAMLCompatible), type):
     """The class of Executor type, which is the metaclass of :class:`BaseExecutor`."""
 
@@ -114,9 +163,11 @@ def register_class(cls):
 class _FunctionWithSchema(NamedTuple):
     fn: Callable
-    is_generator: False
-    is_batch_docs: False
-    is_singleton_doc: False
+    is_generator: bool
+    is_batch_docs: bool
+    is_singleton_doc: bool
+    parameters_is_pydantic_model: bool
+    parameters_model: Type
     request_schema: Type[DocumentArray] = DocumentArray
     response_schema: Type[DocumentArray] = DocumentArray
 
@@ -171,7 +222,6 @@ def validate(self):
 
     @staticmethod
     def get_function_with_schema(fn: Callable) -> T:
-
         # if it's not a generator function, infer the type annotation from the docs parameter
         # otherwise, infer from the doc parameter (since generator endpoints expect only 1 document as input)
         is_generator = getattr(fn, '__is_generator__', False)
@@ -188,6 +238,16 @@
         docs_annotation = fn.__annotations__.get(
             'docs', fn.__annotations__.get('doc', None)
         )
+        parameters_model = (
+            fn.__annotations__.get('parameters', None) if docarray_v2 else None
+        )
+        parameters_is_pydantic_model = False
+        if parameters_model is not None and docarray_v2:
+            from pydantic import BaseModel
+
+            parameters_is_pydantic_model = is_pydantic_model(parameters_model)
+            parameters_model = get_inner_pydantic_model(parameters_model)
+
         if docarray_v2:
             from docarray import BaseDoc, DocList
@@ -256,6 +316,8 @@
             is_generator=is_generator,
             is_singleton_doc=is_singleton_doc,
             is_batch_docs=is_batch_docs,
+            parameters_model=parameters_model,
+            parameters_is_pydantic_model=parameters_is_pydantic_model,
             request_schema=request_schema,
             response_schema=response_schema,
         )
@@ -370,6 +432,7 @@ def _get_endpoint_models_dict(self):
             _is_generator = function_with_schema.is_generator
             _is_singleton_doc = function_with_schema.is_singleton_doc
             _is_batch_docs = function_with_schema.is_batch_docs
+            _parameters_model = function_with_schema.parameters_model
             if docarray_v2:
                 # if the endpoint is not a generator endpoint, then the request schema is a DocumentArray and we need
                 # to get the doc_type from the schema
@@ -403,6 +466,12 @@
                 },
                 'is_generator': _is_generator,
                 'is_singleton_doc': _is_singleton_doc,
+                'parameters': {
+                    'name': _parameters_model.__name__
+                    if _parameters_model is not None
+                    else None,
+                    'model': _parameters_model,
+                },
             }
         return endpoint_models
@@ -626,52 +695,81 @@ async def __acall__(self, req_endpoint: str, **kwargs):
     async def __acall_endpoint__(
         self, req_endpoint, tracing_context: Optional['Context'], **kwargs
     ):
+
+        # Decorator to make sure that `parameters` are passed as Pydantic models if needed
+        def parameters_as_pydantic_models_decorator(func, parameters_pydantic_model):
+            @functools.wraps(func)  # use functools.wraps to preserve the wrapped function's metadata
+            def wrapper(*args, **kwargs):
+                parameters = kwargs.get('parameters', None)
+                if parameters is not None:
+                    parameters = parameters_pydantic_model(**parameters)
+                    kwargs['parameters'] = parameters
+                result = func(*args, **kwargs)
+                return result
+
+            return wrapper
+
+        # Decorator to make sure that `docs` are fed one by one to methods that serve a single document
+        def loop_docs_decorator(func):
+            @functools.wraps(func)  # use functools.wraps to preserve the wrapped function's metadata
+            def wrapper(*args, **kwargs):
+                docs = kwargs.pop('docs')
+                if docarray_v2:
+                    from docarray import DocList
+
+                    ret = DocList[response_schema]()
+                else:
+                    ret = DocumentArray()
+                for doc in docs:
+                    f_ret = func(*args, doc=doc, **kwargs)
+                    if f_ret is None:
+                        ret.append(doc)  # this means change in place
+                    else:
+                        ret.append(f_ret)
+                return ret
+
+            return wrapper
+
+        def async_loop_docs_decorator(func):
+            @functools.wraps(func)  # use functools.wraps to preserve the wrapped function's metadata
+            async def wrapper(*args, **kwargs):
+                docs = kwargs.pop('docs')
+                if docarray_v2:
+                    from docarray import DocList
+
+                    ret = DocList[response_schema]()
+                else:
+                    ret = DocumentArray()
+                for doc in docs:
+                    f_ret = await func(*args, doc=doc, **kwargs)
+                    if f_ret is None:
+                        ret.append(doc)  # this means change in place
+                    else:
+                        ret.append(f_ret)
+                return ret
+
+            return wrapper
+
         fn_info = self.requests[req_endpoint]
         original_func = fn_info.fn
         is_generator = fn_info.is_generator
         is_batch_docs = fn_info.is_batch_docs
         response_schema = fn_info.response_schema
+        parameters_model = fn_info.parameters_model
+        is_parameters_pydantic_model = fn_info.parameters_is_pydantic_model
+
+        func = original_func
 
         if is_generator or is_batch_docs:
-            func = original_func
+            pass
         elif kwargs.get('docs', None) is not None:
             # This means I need to pass every doc (most likely 1, but potentially more)
             if iscoroutinefunction(original_func):
-
-                async def loop_func(*args, **kwargs):
-                    docs = kwargs.pop('docs')
-                    if docarray_v2:
-                        from docarray import DocList
-
-                        ret = DocList[response_schema]()
-                    else:
-                        ret = DocumentArray()
-                    for doc in docs:
-                        f_ret = await original_func(*args, doc=doc, **kwargs)
-                        if f_ret is None:
-                            ret.append(doc)  # this means change in place
-                        else:
-                            ret.append(f_ret)
-                    return ret
-
+                func = async_loop_docs_decorator(original_func)
             else:
+                func = loop_docs_decorator(original_func)
-
-                def loop_func(*args, **kwargs):
-                    docs = kwargs.pop('docs')
-                    if docarray_v2:
-                        from docarray import DocList
-
-                        ret = DocList[response_schema]()
-                    else:
-                        ret = DocumentArray()
-                    for doc in docs:
-                        f_ret = original_func(*args, doc=doc, **kwargs)
-                        if f_ret is None:
-                            ret.append(doc)  # this means change in place
-                        else:
-                            ret.append(f_ret)
-                    return ret
-
-            func = loop_func
+        if is_parameters_pydantic_model:
+            func = parameters_as_pydantic_models_decorator(func, parameters_model)
 
         async def exec_func(
             summary, histogram, histogram_metric_labels, tracing_context
diff --git a/jina/serve/executors/decorators.py b/jina/serve/executors/decorators.py
index 86f8cb5ec0ec4..93b9e443efda3 100644
--- a/jina/serve/executors/decorators.py
+++ b/jina/serve/executors/decorators.py
@@ -361,7 +361,14 @@ def _inject_owner_attrs(
         )
 
     fn_with_schema = _FunctionWithSchema(
-        fn=fn_with_schema.fn, is_generator=fn_with_schema.is_generator, is_singleton_doc=fn_with_schema.is_singleton_doc, is_batch_docs=fn_with_schema.is_batch_docs, request_schema=request_schema_arg, response_schema=response_schema_arg
+        fn=fn_with_schema.fn,
+        is_generator=fn_with_schema.is_generator,
+        is_singleton_doc=fn_with_schema.is_singleton_doc,
+        is_batch_docs=fn_with_schema.is_batch_docs,
+        parameters_is_pydantic_model=fn_with_schema.parameters_is_pydantic_model,
+        parameters_model=fn_with_schema.parameters_model,
+        request_schema=request_schema_arg,
+        response_schema=response_schema_arg
     )
     fn_with_schema.validate()
diff --git a/jina/serve/runtimes/gateway/graph/topology_graph.py b/jina/serve/runtimes/gateway/graph/topology_graph.py
index 4c1dd40482009..6222fae51639e 100644
--- a/jina/serve/runtimes/gateway/graph/topology_graph.py
+++ b/jina/serve/runtimes/gateway/graph/topology_graph.py
@@ -190,6 +190,7 @@ async def task():
                     input_model = models_created_by_name[input_model_name]
                     models_schema_list.append(input_model_schema)
                     models_list.append(input_model)
+
                     output_model_name = inner_dict['output']['name']
                     output_model_schema = inner_dict['output']['model']
                     if output_model_schema in models_schema_list:
@@ -206,13 +207,35 @@ async def task():
                         models_created_by_name[output_model_name] = output_model
                     output_model = models_created_by_name[output_model_name]
                     models_schema_list.append(output_model)
-                    models_list.append(input_model)
+                    models_list.append(output_model)
+
+                    parameters_model_name = inner_dict['parameters']['name']
+                    parameters_model_schema = inner_dict['parameters']['model']
+                    if parameters_model_schema is not None:
+                        if parameters_model_schema in models_schema_list:
+                            parameters_model = models_list[
+                                models_schema_list.index(parameters_model_schema)]
+                            models_created_by_name[parameters_model_name] = parameters_model
+                        else:
+                            if parameters_model_name not in models_created_by_name:
+                                from pydantic import BaseModel
+                                parameters_model = _create_pydantic_model_from_schema(parameters_model_schema,
+                                                                                      parameters_model_name,
+                                                                                      {},
+                                                                                      base_class=BaseModel)
+                                models_created_by_name[parameters_model_name] = parameters_model
+                            parameters_model = models_created_by_name[parameters_model_name]
+                        models_schema_list.append(parameters_model_schema)
+                        models_list.append(parameters_model)
+                    else:
+                        parameters_model = None
 
                     self._pydantic_models_by_endpoint[endpoint] = {
                         'input': input_model,
                         'output': output_model,
-                        'is_generator': inner_dict['is_generator'],
-                        'is_singleton_doc': inner_dict['is_singleton_doc']
+                        'is_generator': inner_dict['is_generator'],
+                        'is_singleton_doc': inner_dict['is_singleton_doc'],
+                        'parameters': parameters_model
                     }
                 self._endpoints_proto = endpoints_proto
             else:
@@ -352,6 +375,7 @@ def _get_input_output_model_for_endpoint(self,
                                             previous_output,
                                             previous_is_generator,
                                             previous_is_singleton_doc,
+                                            previous_parameters,
                                             endpoint):
 
         if self._pydantic_models_by_endpoint is not None:
@@ -375,6 +399,7 @@
                     'output': previous_output,
                     'is_generator': is_generator,
                     'is_singleton_doc': self._pydantic_models_by_endpoint[endpoint]['is_singleton_doc'],
+                    'parameters': self._pydantic_models_by_endpoint[endpoint]['parameters'],
                 }
             else:
                 return {
@@ -382,13 +407,15 @@
                     'output': self._pydantic_models_by_endpoint[endpoint]['output'],
                     'is_generator': is_generator,
                     'is_singleton_doc': self._pydantic_models_by_endpoint[endpoint]['is_singleton_doc'],
+                    'parameters': self._pydantic_models_by_endpoint[endpoint]['parameters'],
                 }
         else:
             return {
                 'input': previous_input,
                 'output': previous_output,
                 'is_generator': previous_is_generator,
-                'is_singleton_doc':
False + 'is_singleton_doc': False, + 'parameters': previous_parameters, } return None @@ -398,12 +425,14 @@ def _get_leaf_input_output_model( previous_output, previous_is_generator, previous_is_singleton_doc, + previous_parameters, endpoint: Optional[str] = None, ): new_map = self._get_input_output_model_for_endpoint(previous_input, previous_output, previous_is_generator, previous_is_singleton_doc, + previous_parameters, endpoint) if self.leaf: # I am like a leaf return list([new_map] if new_map is not None else []) # I am the last in the chain @@ -414,6 +443,7 @@ def _get_leaf_input_output_model( previous_output=new_map['output'] if new_map is not None else None, previous_is_generator=new_map['is_generator'] if new_map is not None else None, previous_is_singleton_doc=new_map['is_singleton_doc'] if new_map is not None else None, + previous_parameters=new_map['parameters'] if new_map is not None else None, endpoint=endpoint ) # We are interested in the last one, that will be the task that awaits all the previous @@ -559,12 +589,14 @@ def _get_leaf_input_output_model( previous_output, previous_is_generator, previous_is_singleton_doc, + previous_parameters, endpoint: Optional[str] = None, ): return [{'input': previous_input, 'output': previous_output, 'is_generator': previous_is_generator, - 'is_singleton_doc': previous_is_singleton_doc}] + 'is_singleton_doc': previous_is_singleton_doc, + 'parameters': previous_parameters}] def __init__( self, diff --git a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py index 019ecf8ed3bd5..de286e7aea9f8 100644 --- a/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py +++ b/jina/serve/runtimes/gateway/http_fastapi_app_docarrayv2.py @@ -248,13 +248,15 @@ async def event_generator(): input_doc_model = input_output_map['input'] output_doc_model = input_output_map['output'] is_generator = input_output_map['is_generator'] + parameters_model = input_output_map['parameters'] or Optional[Dict] + default_parameters = ... 
if input_output_map['parameters'] else None _config = inherit_config(InnerConfig, BaseDoc.__config__) endpoint_input_model = pydantic.create_model( f'{endpoint.strip("/")}_input_model', data=(Union[List[input_doc_model], input_doc_model], ...), - parameters=(Optional[Dict], None), + parameters=(parameters_model, default_parameters), header=(Optional[Header], None), __config__=_config, ) diff --git a/jina/serve/runtimes/helper.py b/jina/serve/runtimes/helper.py index 97e9350c21a47..20e7ae8278dd6 100644 --- a/jina/serve/runtimes/helper.py +++ b/jina/serve/runtimes/helper.py @@ -115,7 +115,7 @@ def _create_aux_model_doc_list_to_list(model): ) - def _get_field_from_type(field_schema, field_name, root_schema, cached_models, is_tensor=False, num_recursions=0): + def _get_field_from_type(field_schema, field_name, root_schema, cached_models, is_tensor=False, num_recursions=0, base_class=BaseDoc): field_type = field_schema.get('type', None) tensor_shape = field_schema.get('tensor/array shape', None) if 'anyOf' in field_schema: @@ -126,12 +126,13 @@ def _get_field_from_type(field_schema, field_name, root_schema, cached_models, i ref_name = obj_ref.split('/')[-1] any_of_types.append( _create_pydantic_model_from_schema(root_schema['definitions'][ref_name], ref_name, - cached_models=cached_models)) + cached_models=cached_models, base_class=base_class)) else: any_of_types.append(_get_field_from_type(any_of_schema, field_name, root_schema=root_schema, cached_models=cached_models, is_tensor=tensor_shape is not None, - num_recursions=0)) # No Union of Lists + num_recursions=0, + base_class=base_class)) # No Union of Lists ret = Union[tuple(any_of_types)] for rec in range(num_recursions): ret = List[ret] @@ -163,7 +164,7 @@ def _get_field_from_type(field_schema, field_name, root_schema, cached_models, i additional_props = field_schema['additionalProperties'] if additional_props.get('type') == 'object': ret = Dict[str, _create_pydantic_model_from_schema(additional_props, field_name, - cached_models=cached_models)] + cached_models=cached_models, base_class=base_class)] else: ret = Dict[str, Any] else: @@ -172,21 +173,21 @@ def _get_field_from_type(field_schema, field_name, root_schema, cached_models, i if obj_ref: ref_name = obj_ref.split('/')[-1] ret = _create_pydantic_model_from_schema(root_schema['definitions'][ref_name], ref_name, - cached_models=cached_models) + cached_models=cached_models, base_class=base_class) else: ret = Any else: # object reference in definitions if obj_ref: ref_name = obj_ref.split('/')[-1] ret = DocList[_create_pydantic_model_from_schema(root_schema['definitions'][ref_name], ref_name, - cached_models=cached_models)] + cached_models=cached_models, base_class=base_class)] else: ret = DocList[ - _create_pydantic_model_from_schema(field_schema, field_name, cached_models=cached_models)] + _create_pydantic_model_from_schema(field_schema, field_name, cached_models=cached_models, base_class=base_class)] elif field_type == 'array': ret = _get_field_from_type(field_schema=field_schema.get('items', {}), field_name=field_name, root_schema=root_schema, cached_models=cached_models, - is_tensor=tensor_shape is not None, num_recursions=num_recursions + 1) + is_tensor=tensor_shape is not None, num_recursions=num_recursions + 1, base_class=base_class) else: if num_recursions > 0: raise ValueError(f"Unknown array item type: {field_type} for field_name {field_name}") @@ -195,7 +196,7 @@ def _get_field_from_type(field_schema, field_name, root_schema, cached_models, i return ret - def 
_create_pydantic_model_from_schema(schema: Dict[str, any], model_name: str, cached_models: Dict) -> type: + def _create_pydantic_model_from_schema(schema: Dict[str, any], model_name: str, cached_models: Dict, base_class=BaseDoc) -> type: cached_models = cached_models if cached_models is not None else {} fields: Dict[str, Any] = {} if model_name in cached_models: @@ -208,10 +209,11 @@ def _create_pydantic_model_from_schema(schema: Dict[str, any], model_name: str, cached_models=cached_models, is_tensor=False, num_recursions=0, + base_class=base_class ) fields[field_name] = (field_type, FieldInfo(default=field_schema.pop('default', None), **field_schema)) - model = create_model(model_name, __base__=BaseDoc, **fields) + model = create_model(model_name, __base__=base_class, **fields) model.__config__.title = schema.get('title', model.__config__.title) for k in RESERVED_KEYS: diff --git a/jina/serve/runtimes/worker/http_fastapi_app.py b/jina/serve/runtimes/worker/http_fastapi_app.py index 1c9f9dbd8dbc1..fb3470425fcc4 100644 --- a/jina/serve/runtimes/worker/http_fastapi_app.py +++ b/jina/serve/runtimes/worker/http_fastapi_app.py @@ -98,11 +98,13 @@ async def post(body: input_model, response: Response): if not docarray_v2: req.data.docs = DocumentArray.from_pydantic_model(data) else: + req.document_array_cls = DocList[input_doc_model] req.data.docs = DocList[input_doc_list_model](data) else: if not docarray_v2: req.data.docs = DocumentArray([Document.from_pydantic_model(data)]) else: + req.document_array_cls = DocList[input_doc_model] req.data.docs = DocList[input_doc_list_model]([data]) if body.header is None: req.header.request_id = req.docs[0].id @@ -150,6 +152,8 @@ async def streaming_get(request: Request): input_doc_model = input_output_map['input']['model'] output_doc_model = input_output_map['output']['model'] is_generator = input_output_map['is_generator'] + parameters_model = input_output_map['parameters']['model'] or Optional[Dict] + default_parameters = ... if input_output_map['parameters']['model'] else None if docarray_v2: _config = inherit_config(InnerConfig, BaseDoc.__config__) @@ -159,7 +163,7 @@ async def streaming_get(request: Request): endpoint_input_model = pydantic.create_model( f'{endpoint.strip("/")}_input_model', data=(Union[List[input_doc_model], input_doc_model], ...), - parameters=(Optional[Dict], None), + parameters=(parameters_model, default_parameters), header=(Optional[Header], None), __config__=_config, ) diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 6b72eaf801740..d136361f50bd5 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -51,16 +51,16 @@ class WorkerRequestHandler: _KEY_RESULT = '__results__' def __init__( - self, - args: 'argparse.Namespace', - logger: 'JinaLogger', - metrics_registry: Optional['CollectorRegistry'] = None, - tracer_provider: Optional['trace.TracerProvider'] = None, - meter_provider: Optional['metrics.MeterProvider'] = None, - meter=None, - tracer=None, - deployment_name: str = '', - **kwargs, + self, + args: 'argparse.Namespace', + logger: 'JinaLogger', + metrics_registry: Optional['CollectorRegistry'] = None, + tracer_provider: Optional['trace.TracerProvider'] = None, + meter_provider: Optional['metrics.MeterProvider'] = None, + meter=None, + tracer=None, + deployment_name: str = '', + **kwargs, ): """Initialize private parameters and execute private loading functions. 
@@ -83,8 +83,8 @@ def __init__(
         self._is_closed = False
         if self.metrics_registry:
             with ImportExtensions(
-                    required=True,
-                    help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
+                required=True,
+                help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
             ):
                 from prometheus_client import Counter, Summary
@@ -205,9 +205,9 @@ async def _hot_reload(self):
                     watched_files.add(extra_python_file)
 
         with ImportExtensions(
-                required=True,
-                logger=self.logger,
-                help_text='''hot reload requires watchfiles dependency to be installed. You can do `pip install
-    watchfiles''',
+            required=True,
+            logger=self.logger,
+            help_text='''hot reload requires watchfiles dependency to be installed. You can do `pip install
+    watchfiles`''',
         ):
             from watchfiles import awatch
@@ -274,16 +274,16 @@ def _init_batchqueue_dict(self):
         }
 
     def _init_monitoring(
-            self,
-            metrics_registry: Optional['CollectorRegistry'] = None,
-            meter: Optional['metrics.Meter'] = None,
+        self,
+        metrics_registry: Optional['CollectorRegistry'] = None,
+        meter: Optional['metrics.Meter'] = None,
     ):
         if metrics_registry:
             with ImportExtensions(
-                    required=True,
-                    help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina',
+                required=True,
+                help_text='You need to install the `prometheus_client` to use the monitoring functionality of jina',
             ):
                 from prometheus_client import Counter, Summary
@@ -339,10 +339,10 @@ def _init_monitoring(
         self._sent_response_size_histogram = None
 
     def _load_executor(
-            self,
-            metrics_registry: Optional['CollectorRegistry'] = None,
-            tracer_provider: Optional['trace.TracerProvider'] = None,
-            meter_provider: Optional['metrics.MeterProvider'] = None,
+        self,
+        metrics_registry: Optional['CollectorRegistry'] = None,
+        tracer_provider: Optional['trace.TracerProvider'] = None,
+        meter_provider: Optional['metrics.MeterProvider'] = None,
     ):
         """
         Load the executor to this runtime, specified by ``uses`` CLI argument.
@@ -564,9 +564,9 @@ def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False):
             pass
 
     def _setup_requests(
-            self,
-            requests: List['DataRequest'],
-            exec_endpoint: str,
+        self,
+        requests: List['DataRequest'],
+        exec_endpoint: str,
     ):
         """Execute a request using the executor.
@@ -582,7 +582,7 @@ def _setup_requests(
         return requests, params
 
     async def handle_generator(
-            self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
+        self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
     ) -> Generator:
         """Prepares and executes a request for generator endpoints.
@@ -619,7 +619,7 @@ async def handle_generator(
         )
 
     async def handle(
-            self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
+        self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None
     ) -> DataRequest:
         """Initialize private parameters and execute private loading functions.
@@ -691,7 +691,7 @@ async def handle(
     @staticmethod
     def replace_docs(
-            request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None
+        request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None
     ) -> None:
         """Replaces the docs in a message with new Documents.
@@ -739,7 +739,7 @@ async def close(self): @staticmethod def _get_docs_matrix_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> Tuple[Optional[List['DocumentArray']], Optional[Dict[str, 'DocumentArray']]]: """ Returns a docs matrix from a list of DataRequest objects. @@ -763,7 +763,7 @@ def _get_docs_matrix_from_request( @staticmethod def get_parameters_dict_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> 'Dict': """ Returns a parameters dict from a list of DataRequest objects. @@ -783,7 +783,7 @@ def get_parameters_dict_from_request( @staticmethod def get_docs_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> 'DocumentArray': """ Gets a field from the message @@ -863,7 +863,7 @@ def reduce_requests(requests: List['DataRequest']) -> 'DataRequest': # serving part async def process_single_data( - self, request: DataRequest, context, is_generator: bool = False + self, request: DataRequest, context, is_generator: bool = False ) -> DataRequest: """ Process the received requests and return the result as a new request @@ -877,7 +877,7 @@ async def process_single_data( return await self.process_data([request], context, is_generator=is_generator) async def stream_doc( - self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext' + self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext' ) -> SingleDocumentRequest: """ Process the received requests and return the result as a new request, used for streaming behavior, one doc IN, several out @@ -988,16 +988,19 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: inner_dict['output']['model'] = _create_aux_model_doc_list_to_list( inner_dict['output']['model'] ).schema() + + if inner_dict['parameters']['model'] is not None: + inner_dict['parameters']['model'] = inner_dict['parameters']['model'].schema() else: for endpoint_name, inner_dict in schemas.items(): inner_dict['input']['model'] = inner_dict['input']['model'].schema() inner_dict['output']['model'] = inner_dict['output']['model'].schema() - + inner_dict['parameters'] = {} json_format.ParseDict(schemas, endpoints_proto.schemas) return endpoints_proto def _extract_tracing_context( - self, metadata: 'grpc.aio.Metadata' + self, metadata: 'grpc.aio.Metadata' ) -> Optional['Context']: if self.tracer: from opentelemetry.propagate import extract @@ -1013,7 +1016,7 @@ def _log_data_request(self, request: DataRequest): ) async def process_data( - self, requests: List[DataRequest], context, is_generator: bool = False + self, requests: List[DataRequest], context, is_generator: bool = False ) -> DataRequest: """ Process the received requests and return the result as a new request @@ -1025,7 +1028,7 @@ async def process_data( """ self.logger.debug('recv a process_data request') with MetricsTimer( - self._summary, self._receiving_request_seconds, self._metric_attributes + self._summary, self._receiving_request_seconds, self._metric_attributes ): try: if self.logger.debug_enabled: @@ -1074,8 +1077,8 @@ async def process_data( ) if ( - self.args.exit_on_exceptions - and type(ex).__name__ in self.args.exit_on_exceptions + self.args.exit_on_exceptions + and type(ex).__name__ in self.args.exit_on_exceptions ): self.logger.info('Exiting because of "--exit-on-exceptions".') raise RuntimeTerminated @@ -1100,7 +1103,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto: return info_proto async def stream( - self, request_iterator, 
context=None, *args, **kwargs + self, request_iterator, context=None, *args, **kwargs ) -> AsyncIterator['Request']: """ stream requests from client iterator and stream responses back. @@ -1118,8 +1121,8 @@ async def stream( Call = stream def _create_snapshot_status( - self, - snapshot_directory: str, + self, + snapshot_directory: str, ) -> 'jina_pb2.SnapshotStatusProto': _id = str(uuid.uuid4()) self.logger.debug(f'Generated snapshot id: {_id}') @@ -1132,7 +1135,7 @@ def _create_snapshot_status( ) def _create_restore_status( - self, + self, ) -> 'jina_pb2.SnapshotStatusProto': _id = str(uuid.uuid4()) self.logger.debug(f'Generated restore id: {_id}') @@ -1151,9 +1154,9 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto': """ self.logger.debug(f' Calling snapshot') if ( - self._snapshot - and self._snapshot_thread - and self._snapshot_thread.is_alive() + self._snapshot + and self._snapshot_thread + and self._snapshot_thread.is_alive() ): raise RuntimeError( f'A snapshot with id {self._snapshot.id.value} is currently in progress. Cannot start another.' @@ -1171,7 +1174,7 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto': return self._snapshot async def snapshot_status( - self, request: 'jina_pb2.SnapshotId', context + self, request: 'jina_pb2.SnapshotId', context ) -> 'jina_pb2.SnapshotStatusProto': """ method to start a snapshot process of the Executor @@ -1232,7 +1235,7 @@ async def restore(self, request: 'jina_pb2.RestoreSnapshotCommand', context): return self._restore async def restore_status( - self, request, context + self, request, context ) -> 'jina_pb2.RestoreSnapshotStatusProto': """ method to start a snapshot process of the Executor diff --git a/jina/serve/stream/__init__.py b/jina/serve/stream/__init__.py index b6d919975369f..fe88f625ddb35 100644 --- a/jina/serve/stream/__init__.py +++ b/jina/serve/stream/__init__.py @@ -87,6 +87,7 @@ async def _get_endpoints_input_output_models(self, topology_graph, connection_po previous_output=None, previous_is_generator=None, previous_is_singleton_doc=None, + previous_parameters=None, endpoint=endp) if leaf_input_output_model is not None and len(leaf_input_output_model) > 0: _endpoints_models_map[endp] = leaf_input_output_model[0] diff --git a/jina/types/request/data.py b/jina/types/request/data.py index 1d78b1f37e0ab..427a447854dcc 100644 --- a/jina/types/request/data.py +++ b/jina/types/request/data.py @@ -306,7 +306,12 @@ def parameters(self, value: Dict): :param value: a Python dict """ self.proto_wo_data.parameters.Clear() - self.proto_wo_data.parameters.update(value) + parameters = value + if docarray_v2: + from pydantic import BaseModel + if isinstance(value, BaseModel): + parameters = dict(value) + self.proto_wo_data.parameters.update(parameters) @property def response(self): @@ -655,7 +660,12 @@ def parameters(self, value: Dict): :param value: a Python dict """ self.proto_wo_data.parameters.Clear() - self.proto_wo_data.parameters.update(value) + parameters = value + if docarray_v2: + from pydantic import BaseModel + if isinstance(value, BaseModel): + parameters = dict(value) + self.proto_wo_data.parameters.update(parameters) def __copy__(self): return SingleDocumentRequest(request=self.proto_with_data) diff --git a/tests/integration/docarray_v2/test_parameters_as_pydantic.py b/tests/integration/docarray_v2/test_parameters_as_pydantic.py new file mode 100644 index 0000000000000..60ea26eabc96c --- /dev/null +++ b/tests/integration/docarray_v2/test_parameters_as_pydantic.py @@ 
-0,0 +1,225 @@ +import pytest +from typing import Dict +from jina import Flow, Deployment, Executor, requests +from docarray import DocList, BaseDoc +from docarray.documents import TextDoc +from jina.helper import random_port +from pydantic import BaseModel + + +@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) +@pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow']) +@pytest.mark.parametrize('parameters_in_client', ['dict', 'model']) +def test_parameters_as_pydantic(protocol, ctxt_manager, parameters_in_client): + if ctxt_manager == 'deployment' and protocol == 'websocket': + return + + class Parameters(BaseModel): + param: str + num: int = 5 + + class FooParameterExecutor(Executor): + @requests(on='/hello') + def foo(self, docs: DocList[TextDoc], parameters: Parameters, **kwargs) -> DocList[TextDoc]: + for doc in docs: + doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}' + + @requests(on='/hello_single') + def bar(self, doc: TextDoc, parameters: Parameters, **kwargs) -> TextDoc: + doc.text = f'Processed by bar with param: {parameters.param} and num: {parameters.num}' + + if ctxt_manager == 'flow': + ctxt_mgr = Flow(protocol=protocol).add(uses=FooParameterExecutor) + else: + ctxt_mgr = Deployment(protocol=protocol, uses=FooParameterExecutor) + + params_to_send = {'param': 'value'} if parameters_in_client == 'dict' else Parameters(param='value') + with ctxt_mgr: + ret = ctxt_mgr.post( + on='/hello', + parameters=params_to_send, + inputs=DocList[TextDoc]([TextDoc(text='')]), + ) + assert len(ret) == 1 + assert ret[0].text == 'Processed by foo with param: value and num: 5' + + ret = ctxt_mgr.post( + on='/hello_single', + parameters=params_to_send, + inputs=DocList[TextDoc]([TextDoc(text='')]), + ) + assert len(ret) == 1 + assert ret[0].text == 'Processed by bar with param: value and num: 5' + if protocol == 'http': + import requests as global_requests + for endpoint in {'hello', 'hello_single'}: + processed_by = 'foo' if endpoint == 'hello' else 'bar' + url = f'http://localhost:{ctxt_mgr.port}/{endpoint}' + myobj = {'data': {'text': ''}, 'parameters': {'param': 'value'}} + resp = global_requests.post(url, json=myobj) + resp_json = resp.json() + assert resp_json['data'][0]['text'] == f'Processed by {processed_by} with param: value and num: 5' + myobj = {'data': [{'text': ''}], 'parameters': {'param': 'value'}} + resp = global_requests.post(url, json=myobj) + resp_json = resp.json() + assert resp_json['data'][0]['text'] == f'Processed by {processed_by} with param: value and num: 5' + + +@pytest.mark.parametrize('protocol', ['http', 'websocket', 'grpc']) +@pytest.mark.parametrize('ctxt_manager', ['deployment', 'flow']) +def test_parameters_invalid(protocol, ctxt_manager): + if ctxt_manager == 'deployment' and protocol == 'websocket': + return + + class Parameters(BaseModel): + param: str + num: int + + class FooInvalidParameterExecutor(Executor): + @requests(on='/hello') + def foo(self, docs: DocList[TextDoc], parameters: Parameters, **kwargs) -> DocList[TextDoc]: + for doc in docs: + doc.text += f'Processed by foo with param: {parameters.param} and num: {parameters.num}' + + if ctxt_manager == 'flow': + ctxt_mgr = Flow(protocol=protocol).add(uses=FooInvalidParameterExecutor) + else: + ctxt_mgr = Deployment(protocol=protocol, uses=FooInvalidParameterExecutor) + + params_to_send = {'param': 'value'} + with ctxt_mgr: + with pytest.raises(Exception): + _ = ctxt_mgr.post( + on='/hello', + parameters=params_to_send, + 
inputs=DocList[TextDoc]([TextDoc(text='')]), + ) + + +@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) +def test_parameters_as_pydantic_in_flow_only_first(protocol): + class Input1(BaseDoc): + text: str + + class Output1(BaseDoc): + price: int + + class Output2(BaseDoc): + a: str + + class ParametersFirst(BaseModel): + mult: int + + class Exec1Chain(Executor): + @requests(on='/bar') + def bar(self, docs: DocList[Input1], parameters: ParametersFirst, **kwargs) -> DocList[Output1]: + docs_return = DocList[Output1]( + [Output1(price=5 * parameters.mult) for _ in range(len(docs))] + ) + return docs_return + + class Exec2Chain(Executor): + @requests(on='/bar') + def bar(self, docs: DocList[Output1], **kwargs) -> DocList[Output2]: + docs_return = DocList[Output2]( + [ + Output2(a=f'final price {docs[0].price}') + for _ in range(len(docs)) + ] + ) + return docs_return + + f = Flow(protocol=protocol).add(uses=Exec1Chain).add(uses=Exec2Chain) + with f: + docs = f.post(on='/bar', inputs=Input1(text='ignored'), parameters={'mult': 10}, return_type=DocList[Output2]) + assert docs[0].a == 'final price 50' + + +@pytest.mark.parametrize('protocol', ['grpc', 'http', 'websocket']) +def test_parameters_as_pydantic_in_flow_second(protocol): + class Input1(BaseDoc): + text: str + + class Output1(BaseDoc): + price: int + + class Output2(BaseDoc): + a: str + + class ParametersSecond(BaseModel): + mult: int + + class Exec1Chain(Executor): + @requests(on='/bar') + def bar(self, docs: DocList[Input1], **kwargs) -> DocList[Output1]: + docs_return = DocList[Output1]( + [Output1(price=5) for _ in range(len(docs))] + ) + return docs_return + + class Exec2Chain(Executor): + @requests(on='/bar') + def bar(self, docs: DocList[Output1], parameters: ParametersSecond, **kwargs) -> DocList[Output2]: + docs_return = DocList[Output2]( + [ + Output2(a=f'final price {docs[0].price * parameters.mult}') + for _ in range(len(docs)) + ] + ) + return docs_return + + f = Flow(protocol=protocol).add(uses=Exec1Chain).add(uses=Exec2Chain) + with f: + docs = f.post(on='/bar', inputs=Input1(text='ignored'), parameters={'mult': 10}, return_type=DocList[Output2]) + assert docs[0].a == 'final price 50' + + +@pytest.mark.parametrize('ctxt_manager', ['flow', 'deployment']) +@pytest.mark.parametrize('include_gateway', [False, True]) +def test_openai(ctxt_manager, include_gateway): + if ctxt_manager == 'flow' and include_gateway: + return + import string + import random + + random_example = ''.join(random.choices(string.ascii_letters, k=10)) + random_description = ''.join(random.choices(string.ascii_letters, k=10)) + from pydantic.fields import Field + from pydantic import BaseModel + class MyDocWithExample(BaseDoc): + """This test should be in description""" + t: str = Field(examples=[random_example], description=random_description) + class Config: + title: str = 'MyDocWithExampleTitle' + schema_extra: Dict = {'extra_key': 'extra_value'} + + class MyConfigParam(BaseModel): + """Configuration for Executor endpoint""" + param1: int = Field(description='batch size', example=256) + + class MyExecDocWithExample(Executor): + @requests + def foo(self, docs: DocList[MyDocWithExample], parameters: MyConfigParam, **kwargs) -> DocList[MyDocWithExample]: + pass + + port = random_port() + + if ctxt_manager == 'flow': + ctxt = Flow(protocol='http', port=port).add(uses=MyExecDocWithExample) + else: + ctxt = Deployment(uses=MyExecDocWithExample, protocol='http', port=port, include_gateway=include_gateway) + + with ctxt: + import requests as 
general_requests + resp = general_requests.get(f'http://localhost:{port}/openapi.json') + resp_str = str(resp.json()) + assert random_example in resp_str + assert random_description in resp_str + assert 'This test should be in description' in resp_str + assert 'MyDocWithExampleTitle' in resp_str + assert 'extra_key' in resp_str + assert 'MyConfigParam' in resp_str + assert 'Configuration for Executor endpoint' in resp_str + assert 'batch size' in resp_str + assert '256' in resp_str +