Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send close frame after ASGI application returned #2331

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 2 additions & 49 deletions tests/protocols/test_http.py

Large diffs are not rendered by default.

192 changes: 37 additions & 155 deletions tests/protocols/test_websocket.py

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions tests/supervisors/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from uvicorn import Server
from uvicorn.config import Config

pytestmark = pytest.mark.anyio


@pytest.mark.anyio
async def test_sigint_finish_req(unused_tcp_port: int):
"""
1. Request is sent
Expand Down Expand Up @@ -42,7 +43,6 @@ async def wait_app(scope, receive, send):
assert req.result().status_code == 200


@pytest.mark.anyio
async def test_sigint_abort_req(unused_tcp_port: int, caplog):
"""
1. Request is sent
Expand Down Expand Up @@ -77,7 +77,6 @@ async def forever_app(scope, receive, send):
assert "Cancel 1 running task(s), timeout graceful shutdown exceeded" in caplog.messages


@pytest.mark.anyio
async def test_sigint_deny_request_after_triggered(unused_tcp_port: int, caplog):
"""
1. Server is started
Expand Down
10 changes: 5 additions & 5 deletions tests/test_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def test():
lifespan = LifespanOn(config)

await lifespan.startup()
assert lifespan.error_occured
assert lifespan.error_occurred
assert not lifespan.should_exit
await lifespan.shutdown()

Expand All @@ -117,7 +117,7 @@ async def test():
lifespan = LifespanOn(config)

await lifespan.startup()
assert lifespan.error_occured
assert lifespan.error_occurred
assert lifespan.should_exit
await lifespan.shutdown()

Expand All @@ -143,7 +143,7 @@ async def test():

await lifespan.startup()
assert lifespan.startup_failed
assert lifespan.error_occured is raise_exception
assert lifespan.error_occurred is raise_exception
assert lifespan.should_exit
await lifespan.shutdown()

Expand Down Expand Up @@ -171,7 +171,7 @@ async def test():

await lifespan.startup()
assert not lifespan.startup_failed
assert not lifespan.error_occured
assert not lifespan.error_occurred
assert not lifespan.should_exit
await lifespan.shutdown()

Expand Down Expand Up @@ -228,7 +228,7 @@ async def test():
assert not lifespan.startup_failed
await lifespan.shutdown()
assert lifespan.shutdown_failed
assert lifespan.error_occured is raise_exception
assert lifespan.error_occurred is raise_exception
assert lifespan.should_exit

loop = asyncio.new_event_loop()
Expand Down
5 changes: 4 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

import asyncio
import os
import socket
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path

from uvicorn import Config, Server


@asynccontextmanager
async def run_server(config: Config, sockets=None):
async def run_server(config: Config, sockets: list[socket.socket] | None = None):
server = Server(config=config)
task = asyncio.create_task(server.serve(sockets=sockets))
await asyncio.sleep(0.1)
Expand Down
18 changes: 8 additions & 10 deletions uvicorn/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,18 +275,16 @@ class LifespanShutdownFailedEvent(TypedDict):


class ASGI2Protocol(Protocol):
def __init__(self, scope: Scope) -> None: ... # pragma: no cover
def __init__(self, scope: Scope) -> None: ...

async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ... # pragma: no cover
async def __call__(self, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ...


ASGI2Application = Type[ASGI2Protocol]
ASGI3Application = Callable[
[
Scope,
ASGIReceiveCallable,
ASGISendCallable,
],
Awaitable[None],
]


class ASGI3Application(Protocol):
async def __call__(self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable) -> None: ...


ASGIApplication = Union[ASGI2Application, ASGI3Application]
24 changes: 13 additions & 11 deletions uvicorn/lifespan/on.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import asyncio
import logging
from asyncio import Queue
from typing import Any, Union
from typing import Any, Union, cast

from uvicorn import Config
from uvicorn._types import (
ASGISendEvent,
LifespanScope,
LifespanShutdownCompleteEvent,
LifespanShutdownEvent,
Expand All @@ -16,8 +17,8 @@
LifespanStartupFailedEvent,
)

LifespanReceiveMessage = Union[LifespanStartupEvent, LifespanShutdownEvent]
LifespanSendMessage = Union[
LifespanReceiveEvent = Union[LifespanStartupEvent, LifespanShutdownEvent]
LifespanSendEvent = Union[
LifespanStartupFailedEvent,
LifespanShutdownFailedEvent,
LifespanStartupCompleteEvent,
Expand All @@ -37,8 +38,8 @@ def __init__(self, config: Config) -> None:
self.logger = logging.getLogger("uvicorn.error")
self.startup_event = asyncio.Event()
self.shutdown_event = asyncio.Event()
self.receive_queue: Queue[LifespanReceiveMessage] = asyncio.Queue()
self.error_occured = False
self.receive_queue: Queue[LifespanReceiveEvent] = asyncio.Queue()
self.error_occurred = False
self.startup_failed = False
self.shutdown_failed = False
self.should_exit = False
Expand All @@ -55,21 +56,21 @@ async def startup(self) -> None:
await self.receive_queue.put(startup_event)
await self.startup_event.wait()

if self.startup_failed or (self.error_occured and self.config.lifespan == "on"):
if self.startup_failed or (self.error_occurred and self.config.lifespan == "on"):
self.logger.error("Application startup failed. Exiting.")
self.should_exit = True
else:
self.logger.info("Application startup complete.")

async def shutdown(self) -> None:
if self.error_occured:
if self.error_occurred:
return
self.logger.info("Waiting for application shutdown.")
shutdown_event: LifespanShutdownEvent = {"type": "lifespan.shutdown"}
await self.receive_queue.put(shutdown_event)
await self.shutdown_event.wait()

if self.shutdown_failed or (self.error_occured and self.config.lifespan == "on"):
if self.shutdown_failed or (self.error_occurred and self.config.lifespan == "on"):
self.logger.error("Application shutdown failed. Exiting.")
self.should_exit = True
else:
Expand All @@ -86,7 +87,7 @@ async def main(self) -> None:
await app(scope, self.receive, self.send)
except BaseException as exc:
self.asgi = None
self.error_occured = True
self.error_occurred = True
if self.startup_failed or self.shutdown_failed:
return
if self.config.lifespan == "auto":
Expand All @@ -99,13 +100,14 @@ async def main(self) -> None:
self.startup_event.set()
self.shutdown_event.set()

async def send(self, message: LifespanSendMessage) -> None:
async def send(self, message: ASGISendEvent) -> None:
assert message["type"] in (
"lifespan.startup.complete",
"lifespan.startup.failed",
"lifespan.shutdown.complete",
"lifespan.shutdown.failed",
)
message = cast(LifespanSendEvent, message)

if message["type"] == "lifespan.startup.complete":
assert not self.startup_event.is_set(), STATE_TRANSITION_ERROR
Expand Down Expand Up @@ -133,5 +135,5 @@ async def send(self, message: LifespanSendMessage) -> None:
if message.get("message"):
self.logger.error(message["message"])

async def receive(self) -> LifespanReceiveMessage:
async def receive(self) -> LifespanReceiveEvent:
return await self.receive_queue.get()
32 changes: 15 additions & 17 deletions uvicorn/protocols/websockets/websockets_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from urllib.parse import unquote

import websockets
import websockets.legacy.handshake
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.extensions import ServerExtensionFactory
from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory
from websockets.legacy.server import HTTPResponse
from websockets.server import WebSocketServerProtocol
Expand Down Expand Up @@ -53,6 +55,7 @@ def is_serving(self) -> bool:

class WebSocketProtocol(WebSocketServerProtocol):
extra_headers: list[tuple[str, str]]
logger: logging.Logger | logging.LoggerAdapter[Any]

def __init__(
self,
Expand Down Expand Up @@ -81,7 +84,6 @@ def __init__(
self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment]

# Connection events
self.scope: WebSocketScope
self.handshake_started_event = asyncio.Event()
self.handshake_completed_event = asyncio.Event()
self.closed_event = asyncio.Event()
Expand All @@ -92,7 +94,7 @@ def __init__(

self.ws_server: Server = Server() # type: ignore[assignment]

extensions = []
extensions: list[ServerExtensionFactory] = []
if self.config.ws_per_message_deflate:
extensions.append(ServerPerMessageDeflateFactory())

Expand All @@ -111,9 +113,7 @@ def __init__(
(name.decode("latin-1"), value.decode("latin-1")) for name, value in server_state.default_headers
]

def connection_made( # type: ignore[override]
self, transport: asyncio.Transport
) -> None:
def connection_made(self, transport: asyncio.Transport) -> None: # type: ignore[override]
self.connections.add(self)
self.transport = transport
self.server = get_local_addr(transport)
Expand Down Expand Up @@ -147,10 +147,10 @@ def shutdown(self) -> None:
self.send_500_response()
self.transport.close()

def on_task_complete(self, task: asyncio.Task) -> None:
def on_task_complete(self, task: asyncio.Task[Any]) -> None:
self.tasks.discard(task)

async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None:
async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None:
"""
This hook is called to determine if the websocket should return
an HTTP response and close.
Expand All @@ -161,21 +161,21 @@ async def process_request(self, path: str, headers: Headers) -> HTTPResponse | N
"""
path_portion, _, query_string = path.partition("?")

websockets.legacy.handshake.check_request(headers)
websockets.legacy.handshake.check_request(request_headers)

subprotocols = []
for header in headers.get_all("Sec-WebSocket-Protocol"):
subprotocols: list[str] = []
for header in request_headers.get_all("Sec-WebSocket-Protocol"):
subprotocols.extend([token.strip() for token in header.split(",")])

asgi_headers = [
(name.encode("ascii"), value.encode("ascii", errors="surrogateescape"))
for name, value in headers.raw_items()
for name, value in request_headers.raw_items()
]
path = unquote(path_portion)
full_path = self.root_path + path
full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii")

self.scope = {
self.scope: WebSocketScope = {
"type": "websocket",
"asgi": {"version": self.config.asgi_version, "spec_version": "2.4"},
"http_version": "1.1",
Expand Down Expand Up @@ -256,12 +256,10 @@ async def run_asgi(self) -> None:
msg = "ASGI callable returned without sending handshake."
self.logger.error(msg)
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
await self.handshake_completed_event.wait()
self.transport.close()
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
await self.handshake_completed_event.wait()
self.transport.close()

async def asgi_send(self, message: ASGISendEvent) -> None:
message_type = message["type"]
Expand Down
6 changes: 2 additions & 4 deletions uvicorn/protocols/websockets/wsproto_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,9 @@ async def run_asgi(self) -> None:
msg = "ASGI callable returned without completing handshake."
self.logger.error(msg)
self.send_500_response()
self.transport.close()
elif result is not None:
msg = "ASGI callable should return None, but returned '%s'."
self.logger.error(msg, result)
self.transport.close()
self.logger.error("ASGI callable should return None, but returned '%s'.", result)
self.transport.close()

async def send(self, message: ASGISendEvent) -> None:
await self.writable.wait()
Expand Down
Loading