diff --git a/pccommon/pccommon/middleware.py b/pccommon/pccommon/middleware.py index 52eeeb69..b91a3c10 100644 --- a/pccommon/pccommon/middleware.py +++ b/pccommon/pccommon/middleware.py @@ -1,96 +1,103 @@ import asyncio import logging import time -from typing import Awaitable, Callable +from functools import wraps +from typing import Any, Callable -from fastapi import HTTPException, Request, Response +from fastapi import Request from fastapi.applications import FastAPI -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse +from fastapi.dependencies.utils import ( + get_body_field, + get_dependant, + get_parameterless_sub_dependant, +) +from fastapi.responses import PlainTextResponse +from fastapi.routing import APIRoute, request_response from starlette.status import HTTP_504_GATEWAY_TIMEOUT -from starlette.types import Message +from starlette.types import ASGIApp, Receive, Scope, Send -from pccommon.logging import get_custom_dimensions from pccommon.tracing import trace_request logger = logging.getLogger(__name__) -async def handle_exceptions( - request: Request, - call_next: Callable[[Request], Awaitable[Response]], -) -> Response: - try: - return await call_next(request) - except HTTPException: - raise - except Exception as e: - logger.exception( - "Exception when handling request", - extra=get_custom_dimensions({"stackTrace": f"{e}"}, request), - ) - raise - - -class RequestTracingMiddleware(BaseHTTPMiddleware): - """Custom middleware to use opencensus request traces - - Middleware implementations that access a Request object directly - will cause subsequent middleware or route handlers to hang. See - - https://github.com/tiangolo/fastapi/issues/394 - - for more details on this implementation. - - An alternative approach is to use dependencies on the APIRouter, but - the stac-fast api implementation makes that difficult without having - to override much of the app initialization. - """ - - def __init__(self, app: FastAPI, service_name: str): - super().__init__(app) +async def http_exception_handler(request: Request, exc: Exception) -> Any: + logger.exception("Exception when handling request", exc_info=exc) + raise + + +def with_timeout( + timeout_seconds: float, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def with_timeout_(func: Callable[..., Any]) -> Callable[..., Any]: + if asyncio.iscoroutinefunction(func): + logger.debug("Adding timeout to function %s", func.__name__) + + @wraps(func) + async def inner(*args: Any, **kwargs: Any) -> Any: + start_time = time.monotonic() + try: + return await asyncio.wait_for( + func(*args, **kwargs), timeout=timeout_seconds + ) + except asyncio.TimeoutError as e: + process_time = time.monotonic() - start_time + # don't have a request object here to get custom dimensions. + log_dimensions = { + "request_time": process_time, + } + logger.exception( + f"Request timeout {e}", + extra=log_dimensions, + ) + + ref_id = log_dimensions.get("ref_id") + debug_msg = ( + f" Debug information for support: {ref_id}" if ref_id else "" + ) + + return PlainTextResponse( + f"The request exceeded the maximum allowed time, please" + " try again. If the issue persists, please contact " + "planetarycomputer@microsoft.com." + f"\n\n{debug_msg}", + status_code=HTTP_504_GATEWAY_TIMEOUT, + ) + + return inner + else: + return func + + return with_timeout_ + + +def add_timeout(app: FastAPI, timeout_seconds: float) -> None: + for route in app.router.routes: + if isinstance(route, APIRoute): + new_endpoint = with_timeout(timeout_seconds)(route.endpoint) + route.endpoint = new_endpoint + route.dependant = get_dependant(path=route.path_format, call=route.endpoint) + for depends in route.dependencies[::-1]: + route.dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant( + depends=depends, path=route.path_format + ), + ) + route.body_field = get_body_field( + dependant=route.dependant, name=route.unique_id + ) + route.app = request_response(route.get_route_handler()) + + +class TraceMiddleware: + def __init__(self, app: ASGIApp, service_name: str): + self.app = app self.service_name = service_name - async def set_body(self, request: Request) -> None: - receive_ = await request._receive() - - async def receive() -> Message: - return receive_ - - request._receive = receive - - async def dispatch( - self, request: Request, call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: - await self.set_body(request) - response = await trace_request(self.service_name, request, call_next) - return response - - -async def timeout_middleware( - request: Request, - call_next: Callable[[Request], Awaitable[Response]], - timeout: int, -) -> Response: - try: - start_time = time.time() - return await asyncio.wait_for(call_next(request), timeout=timeout) - - except asyncio.TimeoutError: - process_time = time.time() - start_time - log_dimensions = get_custom_dimensions({"request_time": process_time}, request) - - logger.exception( - "Request timeout", - extra=log_dimensions, - ) - - ref_id = log_dimensions["custom_dimensions"].get("ref_id") - debug_msg = f"Debug information for support: {ref_id}" if ref_id else "" - - return PlainTextResponse( - f"The request exceeded the maximum allowed time, please try again." - " If the issue persists, please contact planetarycomputer@microsoft.com." - f"\n\n{debug_msg}", - status_code=HTTP_504_GATEWAY_TIMEOUT, - ) + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + request: Request = Request(scope, receive) + await trace_request(self.service_name, request) + + await self.app(scope, receive, send) diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py index 80aabc1e..8e28e6c0 100644 --- a/pccommon/pccommon/tracing.py +++ b/pccommon/pccommon/tracing.py @@ -1,19 +1,21 @@ import json import logging import re -from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast +from typing import List, Optional, Tuple, Union, cast -from fastapi import Request, Response +import fastapi +from fastapi import Request from opencensus.ext.azure.trace_exporter import AzureExporter +from opencensus.trace import execution_context from opencensus.trace.samplers import ProbabilitySampler from opencensus.trace.span import SpanKind from opencensus.trace.tracer import Tracer +from starlette.datastructures import QueryParams from pccommon.config import get_apis_config from pccommon.constants import ( HTTP_METHOD, HTTP_PATH, - HTTP_STATUS_CODE, HTTP_URL, X_AZURE_REF, X_REQUEST_ENTITY, @@ -41,8 +43,7 @@ async def trace_request( service_name: str, request: Request, - call_next: Callable[[Request], Awaitable[Response]], -) -> Response: +) -> None: """Construct a request trace with custom dimensions""" request_path = request_to_path(request).strip("/") @@ -99,17 +100,6 @@ async def trace_request( attribute_key="item", attribute_value=item_id ) - # Call next middleware - response = await call_next(request) - - # Include response dimensions in the trace - tracer.add_attribute_to_current_span( - attribute_key=HTTP_STATUS_CODE, attribute_value=response.status_code - ) - return response - else: - return await call_next(request) - collection_id_re = re.compile( r".*/collections/?(?P[a-zA-Z0-9\-\%]+)?(/items/(?P.*))?.*" # noqa @@ -122,7 +112,6 @@ async def _collection_item_from_request( ) -> Tuple[Optional[str], Optional[str]]: """Attempt to get collection and item ids from the request path or querystring.""" url = request.url - path = url.path.strip("/") try: collection_id_match = collection_id_re.match(f"{url}") if collection_id_match: @@ -131,8 +120,6 @@ async def _collection_item_from_request( ) item_id = cast(Optional[str], collection_id_match.group("item_id")) return (collection_id, item_id) - elif path.endswith("/search") or path.endswith("/register"): - return await _parse_collection_from_search(request) else: collection_id = request.query_params.get("collection") # Some endpoints, like preview/, take an `items` parameter, but @@ -159,35 +146,6 @@ def _should_trace_request(request: Request) -> bool: ) -async def _parse_collection_from_search( - request: Request, -) -> Tuple[Optional[str], Optional[str]]: - """ - Parse the collection id from a search request. - - The search endpoint is a bit of a special case. If it's a GET, the collection - and item ids are in the querystring. If it's a POST, the collection and item may - be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body. - """ - - if request.method.lower() == "get": - collection_id = request.query_params.get("collections") - item_id = request.query_params.get("ids") - return (collection_id, item_id) - elif request.method.lower() == "post": - try: - body = await request.json() - if "collections" in body: - return _parse_queryjson(body) - elif "filter" in body: - return _parse_cqljson(body["filter"]) - except json.JSONDecodeError: - logger.warning( - "Unable to parse search body as JSON. Ignoring collection parameter." - ) - return (None, None) - - def _parse_cqljson(cql: dict) -> Tuple[Optional[str], Optional[str]]: """ Parse the collection id from a CQL-JSON filter. @@ -232,6 +190,8 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]: provided property name, if found. Typical usage will be to provide `collection` and `id`. """ + if cql is None: + return None for _, v in cql.items(): if isinstance(v, dict): result = _iter_cql(v, property_name) @@ -249,3 +209,53 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]: return result # No collection was found return None + + +def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None: + """ + Try to add the Collection ID and Item ID from a search to the current span. + """ + collection_id, item_id = parse_collection_from_search( + json.loads(search_json), request.method, request.query_params + ) + parent_span = getattr(request.state, "parent_span", None) + + current_span = execution_context.get_current_span() or parent_span + + if current_span: + if collection_id is not None: + current_span.add_attribute("collection", collection_id) + if item_id is not None: + current_span.add_attribute("item", item_id) + else: + logger.warning("No active or parent span available for adding attributes.") + + +def parse_collection_from_search( + body: dict, + method: str, + query_params: QueryParams, +) -> Tuple[Optional[str], Optional[str]]: + """ + Parse the collection id from a search request. + + The search endpoint is a bit of a special case. If it's a GET, the collection + and item ids are in the querystring. If it's a POST, the collection and item may + be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body. + """ + if method.lower() == "get": + collection_id = query_params.get("collections") + item_id = query_params.get("ids") + return (collection_id, item_id) + elif method.lower() == "post": + try: + if body.get("collections") is not None: + return _parse_queryjson(body) + elif "filter" in body: + return _parse_cqljson(body["filter"]) + except json.JSONDecodeError as e: + logger.warning( + "Unable to parse search body as JSON. Ignoring collection" + f"parameter. {e}" + ) + return (None, None) diff --git a/pccommon/tests/test_timeouts.py b/pccommon/tests/test_timeouts.py index e89ef9ca..b32ce056 100644 --- a/pccommon/tests/test_timeouts.py +++ b/pccommon/tests/test_timeouts.py @@ -1,14 +1,14 @@ import asyncio -import random -from typing import Awaitable, Callable +from typing import Any import pytest -from fastapi import FastAPI, Request, Response -from fastapi.responses import PlainTextResponse +from fastapi import FastAPI + +# from fastapi.responses import PlainTextResponse from httpx import AsyncClient -from starlette.status import HTTP_200_OK, HTTP_504_GATEWAY_TIMEOUT +from starlette.status import HTTP_504_GATEWAY_TIMEOUT -from pccommon.middleware import timeout_middleware +from pccommon.middleware import add_timeout TIMEOUT_SECONDS = 2 BASE_URL = "http://test" @@ -20,80 +20,22 @@ app.state.service_name = "test" -@app.middleware("http") -async def _timeout_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a timeout to all requests.""" - return await timeout_middleware(request, call_next, timeout=TIMEOUT_SECONDS) - - -# Test endpoint to sleep for a configurable amount of time, which may exceed the -# timeout middleware setting -@app.get("/sleep", response_class=PlainTextResponse) -async def route_for_test(t: int) -> str: - await asyncio.sleep(t) - return "Done" - - -# Test endpoint to sleep and confirm that the task is cancelled after the timeout -@app.get("/cancel", response_class=PlainTextResponse) -async def route_for_cancel_test(t: int) -> str: - for i in range(t): - await asyncio.sleep(1) - if i > TIMEOUT_SECONDS: - raise Exception("Task should have been cancelled") - - return "Done" - - -# Test middleware -# =============== - - -async def success_response(client: AsyncClient, timeout: int) -> None: - print("making request") - response = await client.get("/sleep", params={"t": timeout}) - assert response.status_code == HTTP_200_OK - assert response.text == "Done" +@app.get("/asleep") +async def asleep() -> Any: + await asyncio.sleep(1) + return {} -async def timeout_response(client: AsyncClient, timeout: int) -> None: - response = await client.get("/sleep", params={"t": timeout}) - assert response.status_code == HTTP_504_GATEWAY_TIMEOUT - - -@pytest.mark.asyncio -async def test_timeout() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - await timeout_response(client, 10) - +# Run this after registering the routes -@pytest.mark.asyncio -async def test_no_timeout() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - await success_response(client, 1) +add_timeout(app, timeout_seconds=0.001) @pytest.mark.asyncio -async def test_multiple_requests() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - timeout_tasks = [] - for _ in range(100): - t = TIMEOUT_SECONDS + random.randint(1, 10) - timeout_tasks.append(asyncio.ensure_future(timeout_response(client, t))) - - await asyncio.gather(*timeout_tasks) - - success_tasks = [] - for _ in range(100): - t = TIMEOUT_SECONDS - 1 - success_tasks.append(asyncio.ensure_future(success_response(client, t))) +async def test_add_timeout() -> None: - await asyncio.gather(*success_tasks) + client = AsyncClient(app=app, base_url=BASE_URL) + response = await client.get("/asleep") -@pytest.mark.asyncio -async def test_request_cancelled() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - await client.get("/cancel", params={"t": 10}) + assert response.status_code == HTTP_504_GATEWAY_TIMEOUT diff --git a/pcstac/pcstac/client.py b/pcstac/pcstac/client.py index 9e162d74..017f6034 100644 --- a/pcstac/pcstac/client.py +++ b/pcstac/pcstac/client.py @@ -20,6 +20,7 @@ from pccommon.constants import DEFAULT_COLLECTION_REGION from pccommon.logging import get_custom_dimensions from pccommon.redis import back_pressure, cached_result, rate_limit +from pccommon.tracing import add_stac_attributes_from_search from pcstac.config import API_DESCRIPTION, API_LANDING_PAGE_ID, API_TITLE, get_settings from pcstac.contants import ( CACHE_KEY_COLLECTION, @@ -227,6 +228,8 @@ async def _fetch() -> ItemCollection: return item_collection search_json = search_request.json() + add_stac_attributes_from_search(search_json, request) + logger.info( "STAC: Item search body", extra=get_custom_dimensions({"search_body": search_json}, request), diff --git a/pcstac/pcstac/config.py b/pcstac/pcstac/config.py index 851c9ebb..e994042d 100644 --- a/pcstac/pcstac/config.py +++ b/pcstac/pcstac/config.py @@ -98,7 +98,7 @@ class Settings(BaseSettings): api_version: str = f"v{API_VERSION}" rate_limits: RateLimits = RateLimits() back_pressures: BackPressures = BackPressures() - request_timout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) + request_timeout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) def get_tiler_href(self, request: Request) -> str: """Generates the tiler HREF. diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 55870a5a..55ba7950 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -1,9 +1,9 @@ """FastAPI application using PGStac.""" import logging import os -from typing import Any, Awaitable, Callable, Dict +from typing import Any, Dict -from fastapi import FastAPI, HTTPException, Request, Response +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi from fastapi.responses import ORJSONResponse @@ -15,11 +15,7 @@ from starlette.responses import PlainTextResponse from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import ( - RequestTracingMiddleware, - handle_exceptions, - timeout_middleware, -) +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis from pcstac.api import PCStacApi @@ -77,6 +73,10 @@ app.state.service_name = ServiceName.STAC +add_timeout(app, app_settings.request_timeout) + +app.add_middleware(TraceMiddleware, service_name=app.state.service_name) + # Note: If requests are being sent through an application gateway like # nginx-ingress, you may need to configure CORS through that system. app.add_middleware( @@ -86,25 +86,6 @@ allow_headers=["*"], ) -app.add_middleware(RequestTracingMiddleware, service_name=ServiceName.STAC) - - -@app.middleware("http") -async def _timeout_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a timeout to all requests.""" - return await timeout_middleware( - request, call_next, timeout=app_settings.request_timout - ) - - -@app.middleware("http") -async def _handle_exceptions( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - return await handle_exceptions(request, call_next) - @app.on_event("startup") async def startup_event() -> None: @@ -119,13 +100,7 @@ async def shutdown_event() -> None: await close_db_connection(app) -@app.exception_handler(HTTPException) -async def http_exception_handler( - request: Request, exc: HTTPException -) -> PlainTextResponse: - return PlainTextResponse( - str(exc.detail), status_code=exc.status_code, headers=exc.headers - ) +app.add_exception_handler(Exception, http_exception_handler) @app.exception_handler(StarletteHTTPException) diff --git a/pctiler/pctiler/config.py b/pctiler/pctiler/config.py index b2610353..87a89932 100644 --- a/pctiler/pctiler/config.py +++ b/pctiler/pctiler/config.py @@ -46,7 +46,7 @@ class Settings(BaseSettings): default_max_items_per_tile: int = Field( env=DEFAULT_MAX_ITEMS_PER_TILE_ENV_VAR, default=10 ) - request_timout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) + request_timeout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) feature_flags: FeatureFlags = FeatureFlags() diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py index ea2c0a75..93d1da3b 100755 --- a/pctiler/pctiler/main.py +++ b/pctiler/pctiler/main.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import logging import os -from typing import Awaitable, Callable, Dict, List +from typing import Dict, List -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from morecantile.defaults import tms as defaultTileMatrices from morecantile.models import TileMatrixSet @@ -19,11 +19,7 @@ from pccommon.constants import X_REQUEST_ENTITY from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import ( - RequestTracingMiddleware, - handle_exceptions, - timeout_middleware, -) +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pctiler.config import get_settings from pctiler.endpoints import ( @@ -88,27 +84,15 @@ app.include_router(health.health_router, tags=["Liveliness/Readiness"]) -app.add_middleware(RequestTracingMiddleware, service_name=ServiceName.TILER) - - -@app.middleware("http") -async def _timeout_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a timeout to all requests.""" - return await timeout_middleware(request, call_next, timeout=settings.request_timout) - +add_timeout(app, settings.request_timeout) +add_exception_handlers(app, DEFAULT_STATUS_CODES) +add_exception_handlers(app, MOSAIC_STATUS_CODES) -@app.middleware("http") -async def _handle_exceptions( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - return await handle_exceptions(request, call_next) +app.add_exception_handler(Exception, http_exception_handler) -add_exception_handlers(app, DEFAULT_STATUS_CODES) -add_exception_handlers(app, MOSAIC_STATUS_CODES) +app.add_middleware(TraceMiddleware, service_name=app.state.service_name) app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600") app.add_middleware(TotalTimeMiddleware)