diff --git a/tests/test_basics.py b/tests/test_basics.py index 9da764c..1612e54 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -4,6 +4,7 @@ import sys import pathlib import ssl +import socket import picows import pytest @@ -94,12 +95,11 @@ def __init__(self, server): self.server_task = asyncio.create_task(server.serve_forever()) async def __aenter__(self): - return self.server + return await self.server.__aenter__() - async def __aexit__(self, exc_type, exc_val, exc_tb): - self.server.close() - await self.server.wait_closed() + async def __aexit__(self, *exc): self.server_task.cancel() + await self.server.__aexit__(*exc) with pytest.raises(asyncio.CancelledError): async with async_timeout.timeout(TIMEOUT): await self.server_task @@ -249,7 +249,7 @@ async def test_server_handshake_timeout(): assert client_reader.at_eof() -@pytest.mark.parametrize("request_path", ["/v1/ws?key=blablabla&data=fhhh", "/v1/ws"]) +@pytest.mark.parametrize("request_path", ["/v1/ws", "/v1/ws?key=blablabla&data=fhhh"]) async def test_request_path_and_params(request_path): def listener_factory(request: picows.WSUpgradeRequest): assert request.method == b"GET" @@ -262,7 +262,8 @@ def listener_factory(request: picows.WSUpgradeRequest): "127.0.0.1", 0, websocket_handshake_timeout=0.1) async with ServerAsyncContext(server): url = f"ws://127.0.0.1:{server.sockets[0].getsockname()[1]}{request_path}" - (_, client) = await picows.ws_connect(picows.WSListener, url) + (transport, _) = await picows.ws_connect(picows.WSListener, url) + transport.disconnect() async def test_route_not_found():