Skip to content

Commit

Permalink
fix: fix caching models from all endpoints, inputs and outputs (#6005)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoanFM authored Aug 2, 2023
1 parent 6fd64da commit 103ddb3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jina/serve/runtimes/gateway/graph/topology_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ async def task():
else:
input_model = _create_pydantic_model_from_schema(input_model_schema,
input_model_name,
{})
models_created_by_name)
models_created_by_name[input_model_name] = input_model
input_model = models_created_by_name[input_model_name]
models_schema_list.append(input_model_schema)
Expand All @@ -203,7 +203,7 @@ async def task():
else:
output_model = _create_pydantic_model_from_schema(output_model_schema,
output_model_name,
{})
models_created_by_name)
models_created_by_name[output_model_name] = output_model
output_model = models_created_by_name[output_model_name]
models_schema_list.append(output_model)
Expand All @@ -221,7 +221,7 @@ async def task():
from pydantic import BaseModel
parameters_model = _create_pydantic_model_from_schema(parameters_model_schema,
parameters_model_name,
{},
models_created_by_name,
base_class=BaseModel)
models_created_by_name[parameters_model_name] = parameters_model
parameters_model = models_created_by_name[parameters_model_name]
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/docarray_v2/test_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,3 +1518,22 @@ def foo(self, docs: DocList[MyDocWithExample], **kwargs) -> DocList[MyDocWithExa
assert 'This test should be in description' in resp_str
assert 'MyDocWithExampleTitle' in resp_str
assert 'extra_key' in resp_str


def test_issue_fastapi_multiple_models_same_name():
class MyRandomModel(BaseDoc):
a: str

class MyInputModel(BaseDoc):
b: Optional[MyRandomModel] = None


class MyFailingExecutor(Executor):
@requests(on='/generate')
def generate(self, docs: DocList[MyInputModel], **kwargs) -> DocList[MyRandomModel]:
return DocList[MyRandomModel]([doc.b for doc in docs])

with Flow(protocol='http').add(uses=MyFailingExecutor) as f:
input_doc = MyRandomModel(a='hello world')
res = f.post(on='/generate', inputs=[MyInputModel(b=MyRandomModel(a='hey'))], return_type=DocList[MyRandomModel])
assert res[0].a == 'hey'

0 comments on commit 103ddb3

Please sign in to comment.