Skip to content

Commit

Permalink
fix: loadbalance stream based on response (#6122)
Browse files Browse the repository at this point in the history
Co-authored-by: Joan Fontanals <[email protected]>
  • Loading branch information
NarekA and JoanFM authored Dec 6, 2023
1 parent 1c2a7c2 commit 129b7b3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 46 deletions.
73 changes: 31 additions & 42 deletions jina/serve/runtimes/gateway/request_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,48 +157,37 @@ async def _load_balance(self, request):
try:
async with aiohttp.ClientSession() as session:

if request.method == 'GET':
request_kwargs = {}
try:
payload = await request.json()
if payload:
request_kwargs['json'] = payload
except Exception:
self.logger.debug('No JSON payload found in request')

async with session.get(
url=target_url, **request_kwargs
) as response:
# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
headers=response.headers,
)

# Prepare the response to send headers
await stream_response.prepare(request)

# Stream the response from the target server to the client
async for chunk in response.content.iter_any():
await stream_response.write(chunk)

# Close the stream response once all chunks are sent
await stream_response.write_eof()
return stream_response

elif request.method == 'POST':
d = await request.read()
import json

async with session.post(
url=target_url, json=json.loads(d.decode())
) as response:
content = await response.read()
return web.Response(
body=content,
status=response.status,
content_type=response.content_type,
)
request_kwargs = {}
try:
payload = await request.json()
if payload:
request_kwargs['json'] = payload
except Exception:
self.logger.debug('No JSON payload found in request')

async with session.request(
request.method,
url=target_url,
auto_decompress=False,
**request_kwargs,
) as response:
# Create a StreamResponse with the same headers and status as the target response
stream_response = web.StreamResponse(
status=response.status,
headers=response.headers,
)

# Prepare the response to send headers
await stream_response.prepare(request)

# Stream the response from the target server to the client
async for chunk in response.content.iter_any():
await stream_response.write(chunk)

# Close the stream response once all chunks are sent
await stream_response.write_eof()
return stream_response

except aiohttp.ClientError as e:
return web.Response(text=f'Error: {str(e)}', status=500)

Expand Down
6 changes: 2 additions & 4 deletions tests/integration/docarray_v2/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,17 @@ async def test_issue_6090_get_params(streaming_deployment):

docs = []
url = (
f"htto://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
f"http://localhost:{streaming_deployment.port}/stream-simple?text=my_input_text"
)
async with aiohttp.ClientSession() as session:

async with session.get(url) as resp:
async for chunk in resp.content.iter_any():
print(chunk)
events = chunk.split(b'event: ')[1:]
for event in events:
if event.startswith(b'update'):
parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX:].decode()
parsed = event[HTTPClientlet.UPDATE_EVENT_PREFIX :].decode()
parsed = SimpleInput.parse_raw(parsed)
print(parsed)
docs.append(parsed)
elif event.startswith(b'end'):
pass
Expand Down

0 comments on commit 129b7b3

Please sign in to comment.