Skip to content

Commit

Permalink
Address python/cpython#118950 in uvloop by porting fix and adding tes…
Browse files Browse the repository at this point in the history
…ts to ensure asyncio.streams code effectively can schedule connection_lost and raise ConnectionResetError
  • Loading branch information
cjavad committed Oct 23, 2024
1 parent 7bb12a1 commit 3803ff4
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 3 deletions.
58 changes: 56 additions & 2 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
skip_tests = False

import asyncio
import os
import sys
import unittest
import weakref

from uvloop import _testbase as tb


class _TestAioHTTP:
class _TestAioHTTP(tb.SSLTestCase):

def test_aiohttp_basic_1(self):

Expand Down Expand Up @@ -95,7 +96,7 @@ async def on_shutdown(app):
async def client():
async with aiohttp.ClientSession() as client:
async with client.ws_connect(
'http://127.0.0.1:{}'.format(port)) as ws:
'http://127.0.0.1:{}'.format(port)) as ws:
await ws.send_str("hello")
async for msg in ws:
assert msg.data == "hello"
Expand All @@ -115,6 +116,59 @@ async def stop():

self.loop.run_until_complete(stop())

def test_aiohttp_connection_lost_when_busy(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest('bug in asyncio #118950, tests in CPython.')

cert = tb._cert_fullname(__file__, 'ssl_cert.pem')
key = tb._cert_fullname(__file__, 'ssl_key.pem')
ssl_context = self._create_server_ssl_context(cert, key)
client_ssl_context = self._create_client_ssl_context()

asyncio.set_event_loop(self.loop)
app = aiohttp.web.Application()

async def handler(request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
async for msg in ws:
print("Received:", msg.data)
return ws

app.router.add_get('/', handler)

runner = aiohttp.web.AppRunner(app)
self.loop.run_until_complete(runner.setup())
host = '0.0.0.0'
site = aiohttp.web.TCPSite(runner, host, '0', ssl_context=ssl_context)
self.loop.run_until_complete(site.start())
port = site._server.sockets[0].getsockname()[1]
session = aiohttp.ClientSession(loop=self.loop)

async def test():
async with session.ws_connect(f"wss://{host}:{port}/", ssl=client_ssl_context) as ws:
transport = ws._writer.transport
s = transport.get_extra_info('socket')

if self.implementation == 'asyncio':
s._sock.close()
else:
os.close(s.fileno())

# FLOW_CONTROL_HIGH_WATER * 1024
bytes_to_send = 64 * 1024
iterations = 10
msg = b'Hello world, still there?'

# Send enough messages to trigger a socket write + one extra
for _ in range(iterations + 1):
await ws.send_bytes(msg * ((bytes_to_send // len(msg)) // iterations))

self.assertRaises(ConnectionResetError, self.loop.run_until_complete, test())

self.loop.run_until_complete(session.close())
self.loop.run_until_complete(runner.cleanup())


@unittest.skipIf(skip_tests, "no aiohttp module")
class Test_UV_AioHTTP(_TestAioHTTP, tb.UVTestCase):
Expand Down
47 changes: 47 additions & 0 deletions tests/test_tcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import asyncio.sslproto
import contextlib
import gc
import os
import select
Expand Down Expand Up @@ -3192,6 +3193,52 @@ async def run_main():

self.loop.run_until_complete(run_main())

def test_connection_lost_when_busy(self):
if self.implementation == 'asyncio':
raise unittest.SkipTest('bug in asyncio #118950, tests in CPython.')

ssl_context = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
client_ssl_context = self._create_client_ssl_context()
port = tb.find_free_port()

@contextlib.asynccontextmanager
async def server():
async def client_handler(reader, writer):
...

srv = await asyncio.start_server(client_handler, '0.0.0.0', port, ssl=ssl_context, reuse_port=True)

try:
yield
finally:
srv.close()

async def client():
reader, writer = await asyncio.open_connection('0.0.0.0', port, ssl=client_ssl_context)
transport = writer.transport
s = transport.get_extra_info('socket')

if self.implementation == 'asyncio':
s._sock.close()
else:
os.close(s.fileno())

# FLOW_CONTROL_HIGH_WATER * 1024
bytes_to_send = 64 * 1024
iterations = 10
msg = b'An really important message :)'

# Busy drain loop
for _ in range(iterations + 1):
writer.write(msg * ((bytes_to_send // len(msg)) // iterations))
await writer.drain()

async def test():
async with server():
await client()

self.assertRaises(ConnectionResetError, self.loop.run_until_complete, test())


class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
pass
Expand Down
5 changes: 4 additions & 1 deletion uvloop/sslproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ cdef class _SSLProtocolTransport:
return self._ssl_protocol._app_protocol

def is_closing(self):
return self._closed
return self._closed or self._ssl_protocol._is_transport_closing()

def close(self):
"""Close the transport.
Expand Down Expand Up @@ -316,6 +316,9 @@ cdef class SSLProtocol:
self._app_transport_created = True
return self._app_transport

def _is_transport_closing(self):
return self._transport is not None and self._transport.is_closing()

def connection_made(self, transport):
"""Called when the low-level connection is made.
Expand Down

0 comments on commit 3803ff4

Please sign in to comment.