From f615e5222b734ab2e4a211f1acb81e50b3d80c29 Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Tue, 26 Mar 2024 12:53:32 -0400 Subject: [PATCH 01/10] Remove BaseHTTPMiddlewares , Ensure origin host is used in STAC links --- pccommon/pccommon/middleware.py | 157 +++++++++++++++-------------- pccommon/tests/test_timeouts.py | 90 +++-------------- pcstac/pcstac/config.py | 2 +- pcstac/pcstac/main.py | 41 ++------ pcstac/pcstac/middleware.py | 63 ++++++++++++ pcstac/tests/test_proxy_headers.py | 88 ++++++++++++++++ pctiler/pctiler/config.py | 2 +- pctiler/pctiler/main.py | 30 +----- 8 files changed, 264 insertions(+), 209 deletions(-) create mode 100644 pcstac/pcstac/middleware.py create mode 100644 pcstac/tests/test_proxy_headers.py diff --git a/pccommon/pccommon/middleware.py b/pccommon/pccommon/middleware.py index 52eeeb69..f04a3ac1 100644 --- a/pccommon/pccommon/middleware.py +++ b/pccommon/pccommon/middleware.py @@ -1,96 +1,99 @@ import asyncio import logging import time -from typing import Awaitable, Callable +from functools import wraps +from typing import Any, Callable -from fastapi import HTTPException, Request, Response +from fastapi import HTTPException, Request from fastapi.applications import FastAPI -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse +from fastapi.dependencies.utils import ( + get_body_field, + get_dependant, + get_parameterless_sub_dependant, +) +from fastapi.responses import PlainTextResponse +from fastapi.routing import APIRoute, request_response from starlette.status import HTTP_504_GATEWAY_TIMEOUT -from starlette.types import Message from pccommon.logging import get_custom_dimensions -from pccommon.tracing import trace_request logger = logging.getLogger(__name__) -async def handle_exceptions( - request: Request, - call_next: Callable[[Request], Awaitable[Response]], -) -> Response: - try: - return await call_next(request) - except HTTPException: +async def http_exception_handler(request: Request, exc: Exception) -> Any: + # Log the exception with additional request info if needed + logger.exception("Exception when handling request", exc_info=exc) + # Return a custom response for HTTPException + if isinstance(exc, HTTPException): raise - except Exception as e: + # Handle other exceptions, possibly with a generic response + else: logger.exception( "Exception when handling request", - extra=get_custom_dimensions({"stackTrace": f"{e}"}, request), + extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request), ) raise -class RequestTracingMiddleware(BaseHTTPMiddleware): - """Custom middleware to use opencensus request traces - - Middleware implementations that access a Request object directly - will cause subsequent middleware or route handlers to hang. See - - https://github.com/tiangolo/fastapi/issues/394 - - for more details on this implementation. - - An alternative approach is to use dependencies on the APIRouter, but - the stac-fast api implementation makes that difficult without having - to override much of the app initialization. - """ - - def __init__(self, app: FastAPI, service_name: str): - super().__init__(app) - self.service_name = service_name - - async def set_body(self, request: Request) -> None: - receive_ = await request._receive() - - async def receive() -> Message: - return receive_ - - request._receive = receive - - async def dispatch( - self, request: Request, call_next: Callable[[Request], Awaitable[Response]] - ) -> Response: - await self.set_body(request) - response = await trace_request(self.service_name, request, call_next) - return response - - -async def timeout_middleware( - request: Request, - call_next: Callable[[Request], Awaitable[Response]], - timeout: int, -) -> Response: - try: - start_time = time.time() - return await asyncio.wait_for(call_next(request), timeout=timeout) - - except asyncio.TimeoutError: - process_time = time.time() - start_time - log_dimensions = get_custom_dimensions({"request_time": process_time}, request) - - logger.exception( - "Request timeout", - extra=log_dimensions, - ) - - ref_id = log_dimensions["custom_dimensions"].get("ref_id") - debug_msg = f"Debug information for support: {ref_id}" if ref_id else "" - - return PlainTextResponse( - f"The request exceeded the maximum allowed time, please try again." - " If the issue persists, please contact planetarycomputer@microsoft.com." - f"\n\n{debug_msg}", - status_code=HTTP_504_GATEWAY_TIMEOUT, - ) +def with_timeout( + timeout_seconds: float, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + def with_timeout_(func: Callable[..., Any]) -> Callable[..., Any]: + if asyncio.iscoroutinefunction(func): + logger.debug("Adding timeout to function %s", func.__name__) + + @wraps(func) + async def inner(*args: Any, **kwargs: Any) -> Any: + start_time = time.monotonic() + try: + return await asyncio.wait_for( + func(*args, **kwargs), timeout=timeout_seconds + ) + except asyncio.TimeoutError as e: + process_time = time.monotonic() - start_time + # don't have a request object here to get custom dimensions. + log_dimensions = { + "request_time": process_time, + } + logger.exception( + f"Request timeout {e}", + extra=log_dimensions, + ) + + ref_id = log_dimensions.get("ref_id") + debug_msg = ( + f" Debug information for support: {ref_id}" if ref_id else "" + ) + + return PlainTextResponse( + f"The request exceeded the maximum allowed time, please" + " try again. If the issue persists, please contact " + "planetarycomputer@microsoft.com." + f"\n\n{debug_msg}", + status_code=HTTP_504_GATEWAY_TIMEOUT, + ) + + return inner + else: + return func + + return with_timeout_ + + +def add_timeout(app: FastAPI, timeout_seconds: float) -> None: + for route in app.router.routes: + if isinstance(route, APIRoute): + new_endpoint = with_timeout(timeout_seconds)(route.endpoint) + route.endpoint = new_endpoint + route.dependant = get_dependant(path=route.path_format, call=route.endpoint) + for depends in route.dependencies[::-1]: + route.dependant.dependencies.insert( + 0, + get_parameterless_sub_dependant( + depends=depends, path=route.path_format + ), + ) + route.body_field = get_body_field( + dependant=route.dependant, name=route.unique_id + ) + route.app = request_response(route.get_route_handler()) diff --git a/pccommon/tests/test_timeouts.py b/pccommon/tests/test_timeouts.py index e89ef9ca..b32ce056 100644 --- a/pccommon/tests/test_timeouts.py +++ b/pccommon/tests/test_timeouts.py @@ -1,14 +1,14 @@ import asyncio -import random -from typing import Awaitable, Callable +from typing import Any import pytest -from fastapi import FastAPI, Request, Response -from fastapi.responses import PlainTextResponse +from fastapi import FastAPI + +# from fastapi.responses import PlainTextResponse from httpx import AsyncClient -from starlette.status import HTTP_200_OK, HTTP_504_GATEWAY_TIMEOUT +from starlette.status import HTTP_504_GATEWAY_TIMEOUT -from pccommon.middleware import timeout_middleware +from pccommon.middleware import add_timeout TIMEOUT_SECONDS = 2 BASE_URL = "http://test" @@ -20,80 +20,22 @@ app.state.service_name = "test" -@app.middleware("http") -async def _timeout_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a timeout to all requests.""" - return await timeout_middleware(request, call_next, timeout=TIMEOUT_SECONDS) - - -# Test endpoint to sleep for a configurable amount of time, which may exceed the -# timeout middleware setting -@app.get("/sleep", response_class=PlainTextResponse) -async def route_for_test(t: int) -> str: - await asyncio.sleep(t) - return "Done" - - -# Test endpoint to sleep and confirm that the task is cancelled after the timeout -@app.get("/cancel", response_class=PlainTextResponse) -async def route_for_cancel_test(t: int) -> str: - for i in range(t): - await asyncio.sleep(1) - if i > TIMEOUT_SECONDS: - raise Exception("Task should have been cancelled") - - return "Done" - - -# Test middleware -# =============== - - -async def success_response(client: AsyncClient, timeout: int) -> None: - print("making request") - response = await client.get("/sleep", params={"t": timeout}) - assert response.status_code == HTTP_200_OK - assert response.text == "Done" +@app.get("/asleep") +async def asleep() -> Any: + await asyncio.sleep(1) + return {} -async def timeout_response(client: AsyncClient, timeout: int) -> None: - response = await client.get("/sleep", params={"t": timeout}) - assert response.status_code == HTTP_504_GATEWAY_TIMEOUT - - -@pytest.mark.asyncio -async def test_timeout() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - await timeout_response(client, 10) - +# Run this after registering the routes -@pytest.mark.asyncio -async def test_no_timeout() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - await success_response(client, 1) +add_timeout(app, timeout_seconds=0.001) @pytest.mark.asyncio -async def test_multiple_requests() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - timeout_tasks = [] - for _ in range(100): - t = TIMEOUT_SECONDS + random.randint(1, 10) - timeout_tasks.append(asyncio.ensure_future(timeout_response(client, t))) - - await asyncio.gather(*timeout_tasks) - - success_tasks = [] - for _ in range(100): - t = TIMEOUT_SECONDS - 1 - success_tasks.append(asyncio.ensure_future(success_response(client, t))) +async def test_add_timeout() -> None: - await asyncio.gather(*success_tasks) + client = AsyncClient(app=app, base_url=BASE_URL) + response = await client.get("/asleep") -@pytest.mark.asyncio -async def test_request_cancelled() -> None: - async with AsyncClient(app=app, base_url=BASE_URL) as client: - await client.get("/cancel", params={"t": 10}) + assert response.status_code == HTTP_504_GATEWAY_TIMEOUT diff --git a/pcstac/pcstac/config.py b/pcstac/pcstac/config.py index 851c9ebb..e994042d 100644 --- a/pcstac/pcstac/config.py +++ b/pcstac/pcstac/config.py @@ -98,7 +98,7 @@ class Settings(BaseSettings): api_version: str = f"v{API_VERSION}" rate_limits: RateLimits = RateLimits() back_pressures: BackPressures = BackPressures() - request_timout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) + request_timeout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) def get_tiler_href(self, request: Request) -> str: """Generates the tiler HREF. diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 55870a5a..8d3fc42c 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -1,9 +1,10 @@ """FastAPI application using PGStac.""" import logging import os -from typing import Any, Awaitable, Callable, Dict +from typing import Any, Dict -from fastapi import FastAPI, HTTPException, Request, Response +from brotli_asgi import BrotliMiddleware +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi from fastapi.responses import ORJSONResponse @@ -16,9 +17,8 @@ from pccommon.logging import ServiceName, init_logging from pccommon.middleware import ( - RequestTracingMiddleware, - handle_exceptions, - timeout_middleware, + add_timeout, + http_exception_handler, ) from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis @@ -32,6 +32,7 @@ get_settings, ) from pcstac.errors import PC_DEFAULT_STATUS_CODES +from pcstac.middleware import ProxyHeaderHostMiddleware from pcstac.search import PCSearch, PCSearchGetRequest, RedisBaseItemCache DEBUG: bool = os.getenv("DEBUG") == "TRUE" or False @@ -70,6 +71,7 @@ search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, response_class=ORJSONResponse, + middlewares=[BrotliMiddleware, ProxyHeaderHostMiddleware], exceptions={**DEFAULT_STATUS_CODES, **PC_DEFAULT_STATUS_CODES}, ) @@ -77,6 +79,8 @@ app.state.service_name = ServiceName.STAC +add_timeout(app, app_settings.request_timeout) + # Note: If requests are being sent through an application gateway like # nginx-ingress, you may need to configure CORS through that system. app.add_middleware( @@ -86,25 +90,6 @@ allow_headers=["*"], ) -app.add_middleware(RequestTracingMiddleware, service_name=ServiceName.STAC) - - -@app.middleware("http") -async def _timeout_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a timeout to all requests.""" - return await timeout_middleware( - request, call_next, timeout=app_settings.request_timout - ) - - -@app.middleware("http") -async def _handle_exceptions( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - return await handle_exceptions(request, call_next) - @app.on_event("startup") async def startup_event() -> None: @@ -119,13 +104,7 @@ async def shutdown_event() -> None: await close_db_connection(app) -@app.exception_handler(HTTPException) -async def http_exception_handler( - request: Request, exc: HTTPException -) -> PlainTextResponse: - return PlainTextResponse( - str(exc.detail), status_code=exc.status_code, headers=exc.headers - ) +app.add_exception_handler(Exception, http_exception_handler) @app.exception_handler(StarletteHTTPException) diff --git a/pcstac/pcstac/middleware.py b/pcstac/pcstac/middleware.py new file mode 100644 index 00000000..85449f13 --- /dev/null +++ b/pcstac/pcstac/middleware.py @@ -0,0 +1,63 @@ +from http.client import HTTP_PORT, HTTPS_PORT +from typing import List + +from stac_fastapi.api.middleware import ProxyHeaderMiddleware +from starlette.types import Receive, Scope, Send + + +class ProxyHeaderHostMiddleware(ProxyHeaderMiddleware): + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + """Call from stac-fastapi framework.""" + if scope["type"] == "http": + proto, domain, port = self._get_forwarded_url_parts(scope) + proto = self._get_header_value_by_name(scope, "x-forwarded-scheme", proto) + domain = self.get_preferred_domain(scope, default_value=domain) + + if domain is not None: + # A port may already be included from the # x-forwarded-host header + domain_has_port = domain.find(":") > -1 + + port_suffix = "" + if not domain_has_port and port is not None: + if (proto == "http" and port != HTTP_PORT) or ( + proto == "https" and port != HTTPS_PORT + ): + port_suffix = f":{port}" + + scope["headers"] = self._replace_header_value_by_name( + scope, + "host", + f"{domain}{port_suffix}", + ) + scope["scheme"] = proto + + await self.app(scope, receive, send) + + def get_forwarded_hosts(self, scope: Scope) -> List[str]: + """ + x-forwarded-host may contain a CSV of hosts, and also itself may show up + multiple times within the scope. + """ + hosts: List[str] = [ + host.decode() + for key, host in scope["headers"] + if key == b"x-forwarded-host" + ] + return [item for maybe_csv in hosts for item in maybe_csv.split(",")] + + def get_preferred_domain(self, scope: Scope, default_value: str) -> str: + """ + Determine and return the most appropriate value for the host domain from + a list of possible sources. + """ + tlds = (".com", ".org", ".net") + domains = self.get_forwarded_hosts(scope) + + for d in domains: + if d.endswith(tlds): + return d + + if len(domains) == 0: + return default_value + + return domains[0] diff --git a/pcstac/tests/test_proxy_headers.py b/pcstac/tests/test_proxy_headers.py new file mode 100644 index 00000000..b9e14224 --- /dev/null +++ b/pcstac/tests/test_proxy_headers.py @@ -0,0 +1,88 @@ +from unittest.mock import AsyncMock + +import pytest + +from pcstac.middleware import ProxyHeaderHostMiddleware + +inputs = [ + [ + [ + (b"host", b"badhost"), + (b"x-forwarded-scheme", b"https"), + (b"x-forwarded-host", b"example"), + (b"x-forwarded-port", b"8000"), + ], + b"example:8000", + ], + [ + [ + (b"host", b"badhost"), + (b"x-forwarded-scheme", b"https"), + (b"x-forwarded-host", b"badhost,example.net"), + (b"x-forwarded-port", b"8000"), + ], + b"example.net:8000", + ], + [ + [ + (b"host", b"badhost"), + (b"x-forwarded-scheme", b"https"), + (b"x-forwarded-host", b"badhost"), + (b"x-forwarded-host", b"example.net"), + (b"x-forwarded-port", b"8000"), + ], + b"example.net:8000", + ], + [ + [ + (b"host", b"badhost"), + (b"x-forwarded-scheme", b"https"), + (b"x-forwarded-host", b"badhost"), + (b"x-forwarded-host", b"alsobad,example.net"), + (b"x-forwarded-port", b"8000"), + ], + b"example.net:8000", + ], + [ + [ + (b"host", b"badhost:8080"), + (b"x-forwarded-scheme", b"https"), + (b"x-forwarded-host", b"localhost:8080"), + ], + b"localhost:8080", + ], + [ + [ + (b"host", b"goodhost:8080"), + (b"x-forwarded-scheme", b"https"), + ], + b"goodhost:8080", + ], +] + + +@pytest.mark.parametrize("scope_headers,output_host", inputs) +async def test_forwarded_for_middleware(scope_headers, output_host): + middleware = ProxyHeaderHostMiddleware(app=AsyncMock()) + + scope = { + "type": "http", + "headers": scope_headers, + } + receive = AsyncMock() + send = AsyncMock() + + await middleware(scope, receive, send) + + assert (b"host", output_host) in scope[ + "headers" + ], "Expected host to match x-forwarded_* values" + + expected_scheme = [ + value.decode() for key, value in scope_headers if key == b"x-forwarded-scheme" + ][0] + assert ( + scope["scheme"] == expected_scheme + ), "Expected scheme to match x-forwarded-scheme value" + + middleware.app.assert_called_once_with(scope, receive, send) diff --git a/pctiler/pctiler/config.py b/pctiler/pctiler/config.py index b2610353..87a89932 100644 --- a/pctiler/pctiler/config.py +++ b/pctiler/pctiler/config.py @@ -46,7 +46,7 @@ class Settings(BaseSettings): default_max_items_per_tile: int = Field( env=DEFAULT_MAX_ITEMS_PER_TILE_ENV_VAR, default=10 ) - request_timout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) + request_timeout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30) feature_flags: FeatureFlags = FeatureFlags() diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py index ea2c0a75..b6be2bdc 100755 --- a/pctiler/pctiler/main.py +++ b/pctiler/pctiler/main.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import logging import os -from typing import Awaitable, Callable, Dict, List +from typing import Dict, List -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from morecantile.defaults import tms as defaultTileMatrices from morecantile.models import TileMatrixSet @@ -19,11 +19,7 @@ from pccommon.constants import X_REQUEST_ENTITY from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import ( - RequestTracingMiddleware, - handle_exceptions, - timeout_middleware, -) +from pccommon.middleware import add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pctiler.config import get_settings from pctiler.endpoints import ( @@ -88,24 +84,8 @@ app.include_router(health.health_router, tags=["Liveliness/Readiness"]) -app.add_middleware(RequestTracingMiddleware, service_name=ServiceName.TILER) - - -@app.middleware("http") -async def _timeout_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a timeout to all requests.""" - return await timeout_middleware(request, call_next, timeout=settings.request_timout) - - -@app.middleware("http") -async def _handle_exceptions( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - return await handle_exceptions(request, call_next) - - +app.add_exception_handler(Exception, http_exception_handler) +add_timeout(app, settings.request_timeout) add_exception_handlers(app, DEFAULT_STATUS_CODES) add_exception_handlers(app, MOSAIC_STATUS_CODES) From b6120b2b1f07b4636a7bd98977630de02bbf8fc5 Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Wed, 27 Mar 2024 11:41:34 -0400 Subject: [PATCH 02/10] lint --- pcstac/pcstac/main.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 8d3fc42c..c4075dfd 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -16,10 +16,7 @@ from starlette.responses import PlainTextResponse from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import ( - add_timeout, - http_exception_handler, -) +from pccommon.middleware import add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis from pcstac.api import PCStacApi From f0da39430c7035bb2c8474cd71f0aa928bf41eaa Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Thu, 28 Mar 2024 11:58:37 -0400 Subject: [PATCH 03/10] remove BrotliMiddleware,ProxyHeaderHostMiddleware, proxytests --- pcstac/pcstac/main.py | 3 - pcstac/pcstac/middleware.py | 63 --------------------- pcstac/tests/test_proxy_headers.py | 88 ------------------------------ 3 files changed, 154 deletions(-) delete mode 100644 pcstac/pcstac/middleware.py delete mode 100644 pcstac/tests/test_proxy_headers.py diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index c4075dfd..09304224 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -3,7 +3,6 @@ import os from typing import Any, Dict -from brotli_asgi import BrotliMiddleware from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi @@ -29,7 +28,6 @@ get_settings, ) from pcstac.errors import PC_DEFAULT_STATUS_CODES -from pcstac.middleware import ProxyHeaderHostMiddleware from pcstac.search import PCSearch, PCSearchGetRequest, RedisBaseItemCache DEBUG: bool = os.getenv("DEBUG") == "TRUE" or False @@ -68,7 +66,6 @@ search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, response_class=ORJSONResponse, - middlewares=[BrotliMiddleware, ProxyHeaderHostMiddleware], exceptions={**DEFAULT_STATUS_CODES, **PC_DEFAULT_STATUS_CODES}, ) diff --git a/pcstac/pcstac/middleware.py b/pcstac/pcstac/middleware.py deleted file mode 100644 index 85449f13..00000000 --- a/pcstac/pcstac/middleware.py +++ /dev/null @@ -1,63 +0,0 @@ -from http.client import HTTP_PORT, HTTPS_PORT -from typing import List - -from stac_fastapi.api.middleware import ProxyHeaderMiddleware -from starlette.types import Receive, Scope, Send - - -class ProxyHeaderHostMiddleware(ProxyHeaderMiddleware): - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - """Call from stac-fastapi framework.""" - if scope["type"] == "http": - proto, domain, port = self._get_forwarded_url_parts(scope) - proto = self._get_header_value_by_name(scope, "x-forwarded-scheme", proto) - domain = self.get_preferred_domain(scope, default_value=domain) - - if domain is not None: - # A port may already be included from the # x-forwarded-host header - domain_has_port = domain.find(":") > -1 - - port_suffix = "" - if not domain_has_port and port is not None: - if (proto == "http" and port != HTTP_PORT) or ( - proto == "https" and port != HTTPS_PORT - ): - port_suffix = f":{port}" - - scope["headers"] = self._replace_header_value_by_name( - scope, - "host", - f"{domain}{port_suffix}", - ) - scope["scheme"] = proto - - await self.app(scope, receive, send) - - def get_forwarded_hosts(self, scope: Scope) -> List[str]: - """ - x-forwarded-host may contain a CSV of hosts, and also itself may show up - multiple times within the scope. - """ - hosts: List[str] = [ - host.decode() - for key, host in scope["headers"] - if key == b"x-forwarded-host" - ] - return [item for maybe_csv in hosts for item in maybe_csv.split(",")] - - def get_preferred_domain(self, scope: Scope, default_value: str) -> str: - """ - Determine and return the most appropriate value for the host domain from - a list of possible sources. - """ - tlds = (".com", ".org", ".net") - domains = self.get_forwarded_hosts(scope) - - for d in domains: - if d.endswith(tlds): - return d - - if len(domains) == 0: - return default_value - - return domains[0] diff --git a/pcstac/tests/test_proxy_headers.py b/pcstac/tests/test_proxy_headers.py deleted file mode 100644 index b9e14224..00000000 --- a/pcstac/tests/test_proxy_headers.py +++ /dev/null @@ -1,88 +0,0 @@ -from unittest.mock import AsyncMock - -import pytest - -from pcstac.middleware import ProxyHeaderHostMiddleware - -inputs = [ - [ - [ - (b"host", b"badhost"), - (b"x-forwarded-scheme", b"https"), - (b"x-forwarded-host", b"example"), - (b"x-forwarded-port", b"8000"), - ], - b"example:8000", - ], - [ - [ - (b"host", b"badhost"), - (b"x-forwarded-scheme", b"https"), - (b"x-forwarded-host", b"badhost,example.net"), - (b"x-forwarded-port", b"8000"), - ], - b"example.net:8000", - ], - [ - [ - (b"host", b"badhost"), - (b"x-forwarded-scheme", b"https"), - (b"x-forwarded-host", b"badhost"), - (b"x-forwarded-host", b"example.net"), - (b"x-forwarded-port", b"8000"), - ], - b"example.net:8000", - ], - [ - [ - (b"host", b"badhost"), - (b"x-forwarded-scheme", b"https"), - (b"x-forwarded-host", b"badhost"), - (b"x-forwarded-host", b"alsobad,example.net"), - (b"x-forwarded-port", b"8000"), - ], - b"example.net:8000", - ], - [ - [ - (b"host", b"badhost:8080"), - (b"x-forwarded-scheme", b"https"), - (b"x-forwarded-host", b"localhost:8080"), - ], - b"localhost:8080", - ], - [ - [ - (b"host", b"goodhost:8080"), - (b"x-forwarded-scheme", b"https"), - ], - b"goodhost:8080", - ], -] - - -@pytest.mark.parametrize("scope_headers,output_host", inputs) -async def test_forwarded_for_middleware(scope_headers, output_host): - middleware = ProxyHeaderHostMiddleware(app=AsyncMock()) - - scope = { - "type": "http", - "headers": scope_headers, - } - receive = AsyncMock() - send = AsyncMock() - - await middleware(scope, receive, send) - - assert (b"host", output_host) in scope[ - "headers" - ], "Expected host to match x-forwarded_* values" - - expected_scheme = [ - value.decode() for key, value in scope_headers if key == b"x-forwarded-scheme" - ][0] - assert ( - scope["scheme"] == expected_scheme - ), "Expected scheme to match x-forwarded-scheme value" - - middleware.app.assert_called_once_with(scope, receive, send) From 78c065e50a7b49704fc6517cacdd4928254b4db4 Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Tue, 9 Apr 2024 12:38:02 -0400 Subject: [PATCH 04/10] add tracing middleware --- pccommon/pccommon/tracing.py | 54 ++++++++++++++++++++++++++++++++++++ pccommon/setup.py | 2 ++ pcstac/pcstac/client.py | 3 ++ pcstac/pcstac/main.py | 14 ++++++++-- pctiler/pctiler/main.py | 14 ++++++++-- 5 files changed, 83 insertions(+), 4 deletions(-) diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py index 80aabc1e..13b5b75f 100644 --- a/pccommon/pccommon/tracing.py +++ b/pccommon/pccommon/tracing.py @@ -3,11 +3,14 @@ import re from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast +import fastapi from fastapi import Request, Response from opencensus.ext.azure.trace_exporter import AzureExporter from opencensus.trace.samplers import ProbabilitySampler from opencensus.trace.span import SpanKind from opencensus.trace.tracer import Tracer +from opentelemetry import trace +from starlette.datastructures import QueryParams from pccommon.config import get_apis_config from pccommon.constants import ( @@ -25,6 +28,11 @@ logger = logging.getLogger(__name__) +COLLECTION = "spatio.collection" +COLLECTIONS = "spatio.collections" +ITEM = "spatio.item" +ITEMS = "spatio.items" + exporter = ( AzureExporter( connection_string=( @@ -249,3 +257,49 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]: return result # No collection was found return None + + +def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None: + """ + Try to add the Collection ID and Item ID from a search to the current span. + """ + collection_id, item_id = parse_collection_from_search( + json.loads(search_json), request.method, request.query_params + ) + span = trace.get_current_span() + + if collection_id is not None: + span.set_attribute(COLLECTIONS, collection_id) + + if item_id is not None: + span.set_attribute(ITEMS, item_id) + + +def parse_collection_from_search( + body: dict, + method: str, + query_params: QueryParams, +) -> Tuple[Optional[str], Optional[str]]: + """ + Parse the collection id from a search request. + + The search endpoint is a bit of a special case. If it's a GET, the collection + and item ids are in the querystring. If it's a POST, the collection and item may + be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body. + """ + if method.lower() == "get": + collection_id = query_params.get("collections") + item_id = query_params.get("ids") + return (collection_id, item_id) + elif method.lower() == "post": + try: + if "collections" in body: + return _parse_queryjson(body) + elif "filter" in body: + return _parse_cqljson(body["filter"]) + except json.JSONDecodeError as e: + logger.warning( + "Unable to parse search body as JSON. Ignoring collection" + f"parameter. {e}" + ) + return (None, None) diff --git a/pccommon/setup.py b/pccommon/setup.py index 60cc7e8e..4197ae50 100644 --- a/pccommon/setup.py +++ b/pccommon/setup.py @@ -20,6 +20,8 @@ "html-sanitizer==2.4", # Soon available as lxml[html_clean] "lxml_html_clean==0.1.0", + "opentelemetry-api==1.21.0", + "opentelemetry-sdk==1.21.0", ] extra_reqs = { diff --git a/pcstac/pcstac/client.py b/pcstac/pcstac/client.py index 9e162d74..017f6034 100644 --- a/pcstac/pcstac/client.py +++ b/pcstac/pcstac/client.py @@ -20,6 +20,7 @@ from pccommon.constants import DEFAULT_COLLECTION_REGION from pccommon.logging import get_custom_dimensions from pccommon.redis import back_pressure, cached_result, rate_limit +from pccommon.tracing import add_stac_attributes_from_search from pcstac.config import API_DESCRIPTION, API_LANDING_PAGE_ID, API_TITLE, get_settings from pcstac.contants import ( CACHE_KEY_COLLECTION, @@ -227,6 +228,8 @@ async def _fetch() -> ItemCollection: return item_collection search_json = search_request.json() + add_stac_attributes_from_search(search_json, request) + logger.info( "STAC: Item search body", extra=get_custom_dimensions({"search_body": search_json}, request), diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 09304224..405ae28c 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -1,9 +1,9 @@ """FastAPI application using PGStac.""" import logging import os -from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, Response from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi from fastapi.responses import ORJSONResponse @@ -18,6 +18,7 @@ from pccommon.middleware import add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis +from pccommon.tracing import trace_request from pcstac.api import PCStacApi from pcstac.client import PCClient from pcstac.config import ( @@ -75,6 +76,15 @@ add_timeout(app, app_settings.request_timeout) + +@app.middleware("http") +async def _request_middleware( + request: Request, call_next: Callable[[Request], Awaitable[Response]] +) -> Response: + """Add a trace to all requests.""" + return await trace_request(ServiceName.STAC, request, call_next) + + # Note: If requests are being sent through an application gateway like # nginx-ingress, you may need to configure CORS through that system. app.add_middleware( diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py index b6be2bdc..3553d0f4 100755 --- a/pctiler/pctiler/main.py +++ b/pctiler/pctiler/main.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import logging import os -from typing import Dict, List +from typing import Awaitable, Callable, Dict, List -from fastapi import FastAPI +from fastapi import FastAPI, Request, Response from fastapi.openapi.utils import get_openapi from morecantile.defaults import tms as defaultTileMatrices from morecantile.models import TileMatrixSet @@ -21,6 +21,7 @@ from pccommon.logging import ServiceName, init_logging from pccommon.middleware import add_timeout, http_exception_handler from pccommon.openapi import fixup_schema +from pccommon.tracing import trace_request from pctiler.config import get_settings from pctiler.endpoints import ( configuration, @@ -89,6 +90,15 @@ add_exception_handlers(app, DEFAULT_STATUS_CODES) add_exception_handlers(app, MOSAIC_STATUS_CODES) + +@app.middleware("http") +async def _request_middleware( + request: Request, call_next: Callable[[Request], Awaitable[Response]] +) -> Response: + """Add a trace to all requests.""" + return await trace_request(ServiceName.TILER, request, call_next) + + app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600") app.add_middleware(TotalTimeMiddleware) From 2fc1ebd3d5b1850a81de1f97eed9cf6d97a212dd Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Wed, 10 Apr 2024 10:17:16 -0400 Subject: [PATCH 05/10] Starlette trace middleware, remove opentelemetry --- pccommon/pccommon/middleware.py | 26 +++++++++++++++++++++++++- pccommon/pccommon/tracing.py | 32 ++++++++++++++++++++------------ pccommon/setup.py | 2 -- pcstac/pcstac/main.py | 16 ++++------------ pctiler/pctiler/main.py | 16 ++++------------ 5 files changed, 53 insertions(+), 39 deletions(-) diff --git a/pccommon/pccommon/middleware.py b/pccommon/pccommon/middleware.py index f04a3ac1..f670563e 100644 --- a/pccommon/pccommon/middleware.py +++ b/pccommon/pccommon/middleware.py @@ -4,7 +4,7 @@ from functools import wraps from typing import Any, Callable -from fastapi import HTTPException, Request +from fastapi import HTTPException, Request, Response from fastapi.applications import FastAPI from fastapi.dependencies.utils import ( get_body_field, @@ -14,8 +14,10 @@ from fastapi.responses import PlainTextResponse from fastapi.routing import APIRoute, request_response from starlette.status import HTTP_504_GATEWAY_TIMEOUT +from starlette.types import ASGIApp, Receive, Scope, Send from pccommon.logging import get_custom_dimensions +from pccommon.tracing import trace_request logger = logging.getLogger(__name__) @@ -97,3 +99,25 @@ def add_timeout(app: FastAPI, timeout_seconds: float) -> None: dependant=route.dependant, name=route.unique_id ) route.app = request_response(route.get_route_handler()) + + +class TraceMiddleware: + def __init__(self, app: ASGIApp, service_name: str): + self.app = app + self.service_name = service_name + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + request: Request = Request(scope, receive) + + async def call_next(request: Request) -> Response: + # Create a response object to mimic trace_requests call_next + # argument + response = Response() + await self.app(scope, receive, send) + return response + + await trace_request(self.service_name, request, call_next) + + else: + await self.app(scope, receive, send) diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py index 13b5b75f..7da7ec8a 100644 --- a/pccommon/pccommon/tracing.py +++ b/pccommon/pccommon/tracing.py @@ -9,7 +9,6 @@ from opencensus.trace.samplers import ProbabilitySampler from opencensus.trace.span import SpanKind from opencensus.trace.tracer import Tracer -from opentelemetry import trace from starlette.datastructures import QueryParams from pccommon.config import get_apis_config @@ -28,11 +27,6 @@ logger = logging.getLogger(__name__) -COLLECTION = "spatio.collection" -COLLECTIONS = "spatio.collections" -ITEM = "spatio.item" -ITEMS = "spatio.items" - exporter = ( AzureExporter( connection_string=( @@ -266,13 +260,27 @@ def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) collection_id, item_id = parse_collection_from_search( json.loads(search_json), request.method, request.query_params ) - span = trace.get_current_span() - - if collection_id is not None: - span.set_attribute(COLLECTIONS, collection_id) + tracer = Tracer( + exporter=exporter, + sampler=ProbabilitySampler(1.0), + ) - if item_id is not None: - span.set_attribute(ITEMS, item_id) + with tracer.span("main") as span: + if ( + hasattr(request.state, "parent_span") + and request.state.parent_span is not None + ): + request.state.parent_span = span + if collection_id is not None: + tracer.add_attribute_to_current_span( + attribute_key="collection", attribute_value=collection_id + ) + if item_id is not None: + tracer.add_attribute_to_current_span( + attribute_key="item", attribute_value=item_id + ) + else: + logger.warning("No 'parent_span' attribute found in request.state") def parse_collection_from_search( diff --git a/pccommon/setup.py b/pccommon/setup.py index 4197ae50..60cc7e8e 100644 --- a/pccommon/setup.py +++ b/pccommon/setup.py @@ -20,8 +20,6 @@ "html-sanitizer==2.4", # Soon available as lxml[html_clean] "lxml_html_clean==0.1.0", - "opentelemetry-api==1.21.0", - "opentelemetry-sdk==1.21.0", ] extra_reqs = { diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 405ae28c..55ba7950 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -1,9 +1,9 @@ """FastAPI application using PGStac.""" import logging import os -from typing import Any, Awaitable, Callable, Dict +from typing import Any, Dict -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi from fastapi.responses import ORJSONResponse @@ -15,10 +15,9 @@ from starlette.responses import PlainTextResponse from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import add_timeout, http_exception_handler +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis -from pccommon.tracing import trace_request from pcstac.api import PCStacApi from pcstac.client import PCClient from pcstac.config import ( @@ -76,14 +75,7 @@ add_timeout(app, app_settings.request_timeout) - -@app.middleware("http") -async def _request_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a trace to all requests.""" - return await trace_request(ServiceName.STAC, request, call_next) - +app.add_middleware(TraceMiddleware, service_name=app.state.service_name) # Note: If requests are being sent through an application gateway like # nginx-ingress, you may need to configure CORS through that system. diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py index 3553d0f4..cb32b9dc 100755 --- a/pctiler/pctiler/main.py +++ b/pctiler/pctiler/main.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import logging import os -from typing import Awaitable, Callable, Dict, List +from typing import Dict, List -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from morecantile.defaults import tms as defaultTileMatrices from morecantile.models import TileMatrixSet @@ -19,9 +19,8 @@ from pccommon.constants import X_REQUEST_ENTITY from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import add_timeout, http_exception_handler +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema -from pccommon.tracing import trace_request from pctiler.config import get_settings from pctiler.endpoints import ( configuration, @@ -91,14 +90,7 @@ add_exception_handlers(app, MOSAIC_STATUS_CODES) -@app.middleware("http") -async def _request_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a trace to all requests.""" - return await trace_request(ServiceName.TILER, request, call_next) - - +app.add_middleware(TraceMiddleware, service_name=app.state.service_name) app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600") app.add_middleware(TotalTimeMiddleware) From 3d4ca72deb3b4fcb1cb850f2d69da1b95a9dd8ff Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Wed, 10 Apr 2024 12:10:50 -0400 Subject: [PATCH 06/10] remove duplicate search tracing, use execution_context as span --- pccommon/pccommon/middleware.py | 31 ++----------- pccommon/pccommon/tracing.py | 82 ++++++--------------------------- pcstac/pcstac/main.py | 12 +++-- pctiler/pctiler/main.py | 15 ++++-- 4 files changed, 37 insertions(+), 103 deletions(-) diff --git a/pccommon/pccommon/middleware.py b/pccommon/pccommon/middleware.py index f670563e..962e5660 100644 --- a/pccommon/pccommon/middleware.py +++ b/pccommon/pccommon/middleware.py @@ -4,7 +4,7 @@ from functools import wraps from typing import Any, Callable -from fastapi import HTTPException, Request, Response +from fastapi import Request from fastapi.applications import FastAPI from fastapi.dependencies.utils import ( get_body_field, @@ -16,27 +16,11 @@ from starlette.status import HTTP_504_GATEWAY_TIMEOUT from starlette.types import ASGIApp, Receive, Scope, Send -from pccommon.logging import get_custom_dimensions from pccommon.tracing import trace_request logger = logging.getLogger(__name__) -async def http_exception_handler(request: Request, exc: Exception) -> Any: - # Log the exception with additional request info if needed - logger.exception("Exception when handling request", exc_info=exc) - # Return a custom response for HTTPException - if isinstance(exc, HTTPException): - raise - # Handle other exceptions, possibly with a generic response - else: - logger.exception( - "Exception when handling request", - extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request), - ) - raise - - def with_timeout( timeout_seconds: float, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -109,15 +93,6 @@ def __init__(self, app: ASGIApp, service_name: str): async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": request: Request = Request(scope, receive) + await trace_request(self.service_name, request) - async def call_next(request: Request) -> Response: - # Create a response object to mimic trace_requests call_next - # argument - response = Response() - await self.app(scope, receive, send) - return response - - await trace_request(self.service_name, request, call_next) - - else: - await self.app(scope, receive, send) + await self.app(scope, receive, send) diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py index 7da7ec8a..8c873eb1 100644 --- a/pccommon/pccommon/tracing.py +++ b/pccommon/pccommon/tracing.py @@ -1,11 +1,12 @@ import json import logging import re -from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast +from typing import List, Optional, Tuple, Union, cast import fastapi -from fastapi import Request, Response +from fastapi import Request from opencensus.ext.azure.trace_exporter import AzureExporter +from opencensus.trace import execution_context from opencensus.trace.samplers import ProbabilitySampler from opencensus.trace.span import SpanKind from opencensus.trace.tracer import Tracer @@ -15,7 +16,6 @@ from pccommon.constants import ( HTTP_METHOD, HTTP_PATH, - HTTP_STATUS_CODE, HTTP_URL, X_AZURE_REF, X_REQUEST_ENTITY, @@ -43,8 +43,7 @@ async def trace_request( service_name: str, request: Request, - call_next: Callable[[Request], Awaitable[Response]], -) -> Response: +) -> None: """Construct a request trace with custom dimensions""" request_path = request_to_path(request).strip("/") @@ -101,17 +100,6 @@ async def trace_request( attribute_key="item", attribute_value=item_id ) - # Call next middleware - response = await call_next(request) - - # Include response dimensions in the trace - tracer.add_attribute_to_current_span( - attribute_key=HTTP_STATUS_CODE, attribute_value=response.status_code - ) - return response - else: - return await call_next(request) - collection_id_re = re.compile( r".*/collections/?(?P[a-zA-Z0-9\-\%]+)?(/items/(?P.*))?.*" # noqa @@ -124,7 +112,6 @@ async def _collection_item_from_request( ) -> Tuple[Optional[str], Optional[str]]: """Attempt to get collection and item ids from the request path or querystring.""" url = request.url - path = url.path.strip("/") try: collection_id_match = collection_id_re.match(f"{url}") if collection_id_match: @@ -133,8 +120,6 @@ async def _collection_item_from_request( ) item_id = cast(Optional[str], collection_id_match.group("item_id")) return (collection_id, item_id) - elif path.endswith("/search") or path.endswith("/register"): - return await _parse_collection_from_search(request) else: collection_id = request.query_params.get("collection") # Some endpoints, like preview/, take an `items` parameter, but @@ -161,35 +146,6 @@ def _should_trace_request(request: Request) -> bool: ) -async def _parse_collection_from_search( - request: Request, -) -> Tuple[Optional[str], Optional[str]]: - """ - Parse the collection id from a search request. - - The search endpoint is a bit of a special case. If it's a GET, the collection - and item ids are in the querystring. If it's a POST, the collection and item may - be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body. - """ - - if request.method.lower() == "get": - collection_id = request.query_params.get("collections") - item_id = request.query_params.get("ids") - return (collection_id, item_id) - elif request.method.lower() == "post": - try: - body = await request.json() - if "collections" in body: - return _parse_queryjson(body) - elif "filter" in body: - return _parse_cqljson(body["filter"]) - except json.JSONDecodeError: - logger.warning( - "Unable to parse search body as JSON. Ignoring collection parameter." - ) - return (None, None) - - def _parse_cqljson(cql: dict) -> Tuple[Optional[str], Optional[str]]: """ Parse the collection id from a CQL-JSON filter. @@ -249,7 +205,6 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]: result = _iter_cql(item, property_name) if result is not None: return result - # No collection was found return None @@ -260,27 +215,16 @@ def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) collection_id, item_id = parse_collection_from_search( json.loads(search_json), request.method, request.query_params ) - tracer = Tracer( - exporter=exporter, - sampler=ProbabilitySampler(1.0), - ) + parent_span = getattr(request.state, "parent_span", None) - with tracer.span("main") as span: - if ( - hasattr(request.state, "parent_span") - and request.state.parent_span is not None - ): - request.state.parent_span = span - if collection_id is not None: - tracer.add_attribute_to_current_span( - attribute_key="collection", attribute_value=collection_id - ) - if item_id is not None: - tracer.add_attribute_to_current_span( - attribute_key="item", attribute_value=item_id - ) - else: - logger.warning("No 'parent_span' attribute found in request.state") + current_span = execution_context.get_current_span() or parent_span + + if current_span: + current_span.add_attribute("collection", collection_id) + if item_id is not None: + current_span.add_attribute("item", item_id) + else: + logger.warning("No active or parent span available for adding attributes.") def parse_collection_from_search( diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 55ba7950..31f7d9ad 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -3,7 +3,7 @@ import os from typing import Any, Dict -from fastapi import FastAPI, Request +from fastapi import FastAPI, HTTPException, Request from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi from fastapi.responses import ORJSONResponse @@ -15,7 +15,7 @@ from starlette.responses import PlainTextResponse from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler +from pccommon.middleware import TraceMiddleware, add_timeout from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis from pcstac.api import PCStacApi @@ -100,7 +100,13 @@ async def shutdown_event() -> None: await close_db_connection(app) -app.add_exception_handler(Exception, http_exception_handler) +@app.exception_handler(HTTPException) +async def http_exception_handler( + request: Request, exc: HTTPException +) -> PlainTextResponse: + return PlainTextResponse( + str(exc.detail), status_code=exc.status_code, headers=exc.headers + ) @app.exception_handler(StarletteHTTPException) diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py index cb32b9dc..fb70c721 100755 --- a/pctiler/pctiler/main.py +++ b/pctiler/pctiler/main.py @@ -3,8 +3,9 @@ import os from typing import Dict, List -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException, Request from fastapi.openapi.utils import get_openapi +from fastapi.responses import PlainTextResponse from morecantile.defaults import tms as defaultTileMatrices from morecantile.models import TileMatrixSet from starlette.middleware.cors import CORSMiddleware @@ -19,7 +20,7 @@ from pccommon.constants import X_REQUEST_ENTITY from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler +from pccommon.middleware import TraceMiddleware, add_timeout from pccommon.openapi import fixup_schema from pctiler.config import get_settings from pctiler.endpoints import ( @@ -84,12 +85,20 @@ app.include_router(health.health_router, tags=["Liveliness/Readiness"]) -app.add_exception_handler(Exception, http_exception_handler) add_timeout(app, settings.request_timeout) add_exception_handlers(app, DEFAULT_STATUS_CODES) add_exception_handlers(app, MOSAIC_STATUS_CODES) +@app.exception_handler(HTTPException) +async def http_exception_handler( + request: Request, exc: HTTPException +) -> PlainTextResponse: + return PlainTextResponse( + str(exc.detail), status_code=exc.status_code, headers=exc.headers + ) + + app.add_middleware(TraceMiddleware, service_name=app.state.service_name) app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600") app.add_middleware(TotalTimeMiddleware) From af0f22743f28637c9497a66d86f8da579fd3f225 Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Thu, 11 Apr 2024 09:43:22 -0400 Subject: [PATCH 07/10] check if collections is set to None --- pccommon/pccommon/tracing.py | 39 +++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py index 8c873eb1..7fe79405 100644 --- a/pccommon/pccommon/tracing.py +++ b/pccommon/pccommon/tracing.py @@ -189,22 +189,28 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]: Recurse through a CQL or CQL2 filter body, returning the value of the provided property name, if found. Typical usage will be to provide `collection` and `id`. + + :param cql: The CQL filter body as a dictionary. + :param property_name: The name of the property to search for, e.g., "collections" or "ids". + :return: The value(s) for the specified property, if found, otherwise None. """ - for _, v in cql.items(): - if isinstance(v, dict): - result = _iter_cql(v, property_name) + if cql is None: + return None + + if property_name in cql: + return cql[property_name] + + for key, value in cql.items(): + if isinstance(value, dict): + result = _iter_cql(value, property_name) if result is not None: return result - elif isinstance(v, list): - for item in v: + elif isinstance(value, list): + for item in value: if isinstance(item, dict): - if "property" in item: - if item["property"] == property_name: - return v[1] - else: - result = _iter_cql(item, property_name) - if result is not None: - return result + result = _iter_cql(item, property_name) + if result is not None: + return result return None @@ -220,9 +226,10 @@ def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) current_span = execution_context.get_current_span() or parent_span if current_span: - current_span.add_attribute("collection", collection_id) - if item_id is not None: - current_span.add_attribute("item", item_id) + if collection_id is not None: + current_span.add_attribute("collection", collection_id) + if item_id is not None: + current_span.add_attribute("item", item_id) else: logger.warning("No active or parent span available for adding attributes.") @@ -245,7 +252,7 @@ def parse_collection_from_search( return (collection_id, item_id) elif method.lower() == "post": try: - if "collections" in body: + if body.get("collections") is not None: return _parse_queryjson(body) elif "filter" in body: return _parse_cqljson(body["filter"]) From f37e938b4fa89d36077d71ebf3d74b46a3821be7 Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Thu, 11 Apr 2024 10:34:03 -0400 Subject: [PATCH 08/10] add a none check for cql parsing --- pccommon/pccommon/tracing.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py index 7fe79405..8e28e6c0 100644 --- a/pccommon/pccommon/tracing.py +++ b/pccommon/pccommon/tracing.py @@ -189,28 +189,25 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]: Recurse through a CQL or CQL2 filter body, returning the value of the provided property name, if found. Typical usage will be to provide `collection` and `id`. - - :param cql: The CQL filter body as a dictionary. - :param property_name: The name of the property to search for, e.g., "collections" or "ids". - :return: The value(s) for the specified property, if found, otherwise None. """ if cql is None: return None - - if property_name in cql: - return cql[property_name] - - for key, value in cql.items(): - if isinstance(value, dict): - result = _iter_cql(value, property_name) + for _, v in cql.items(): + if isinstance(v, dict): + result = _iter_cql(v, property_name) if result is not None: return result - elif isinstance(value, list): - for item in value: + elif isinstance(v, list): + for item in v: if isinstance(item, dict): - result = _iter_cql(item, property_name) - if result is not None: - return result + if "property" in item: + if item["property"] == property_name: + return v[1] + else: + result = _iter_cql(item, property_name) + if result is not None: + return result + # No collection was found return None From 88614ee3212bd6ef3f1c02ca3c01c77694680adf Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Thu, 11 Apr 2024 16:05:39 -0400 Subject: [PATCH 09/10] restore handle http exec --- pccommon/pccommon/middleware.py | 15 ++++++++++++++- pcstac/pcstac/main.py | 12 +++--------- pctiler/pctiler/main.py | 13 +++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pccommon/pccommon/middleware.py b/pccommon/pccommon/middleware.py index 962e5660..18ffce7a 100644 --- a/pccommon/pccommon/middleware.py +++ b/pccommon/pccommon/middleware.py @@ -4,7 +4,7 @@ from functools import wraps from typing import Any, Callable -from fastapi import Request +from fastapi import HTTPException, Request from fastapi.applications import FastAPI from fastapi.dependencies.utils import ( get_body_field, @@ -16,11 +16,24 @@ from starlette.status import HTTP_504_GATEWAY_TIMEOUT from starlette.types import ASGIApp, Receive, Scope, Send +from pccommon.logging import get_custom_dimensions from pccommon.tracing import trace_request logger = logging.getLogger(__name__) +async def http_exception_handler(request: Request, exc: Exception) -> Any: + logger.exception("Exception when handling request", exc_info=exc) + if isinstance(exc, HTTPException): + raise + else: + logger.exception( + "Exception when handling request", + extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request), + ) + raise + + def with_timeout( timeout_seconds: float, ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 31f7d9ad..55ba7950 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -3,7 +3,7 @@ import os from typing import Any, Dict -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi from fastapi.responses import ORJSONResponse @@ -15,7 +15,7 @@ from starlette.responses import PlainTextResponse from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import TraceMiddleware, add_timeout +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis from pcstac.api import PCStacApi @@ -100,13 +100,7 @@ async def shutdown_event() -> None: await close_db_connection(app) -@app.exception_handler(HTTPException) -async def http_exception_handler( - request: Request, exc: HTTPException -) -> PlainTextResponse: - return PlainTextResponse( - str(exc.detail), status_code=exc.status_code, headers=exc.headers - ) +app.add_exception_handler(Exception, http_exception_handler) @app.exception_handler(StarletteHTTPException) diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py index fb70c721..93d1da3b 100755 --- a/pctiler/pctiler/main.py +++ b/pctiler/pctiler/main.py @@ -3,9 +3,8 @@ import os from typing import Dict, List -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI from fastapi.openapi.utils import get_openapi -from fastapi.responses import PlainTextResponse from morecantile.defaults import tms as defaultTileMatrices from morecantile.models import TileMatrixSet from starlette.middleware.cors import CORSMiddleware @@ -20,7 +19,7 @@ from pccommon.constants import X_REQUEST_ENTITY from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import TraceMiddleware, add_timeout +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pctiler.config import get_settings from pctiler.endpoints import ( @@ -90,13 +89,7 @@ add_exception_handlers(app, MOSAIC_STATUS_CODES) -@app.exception_handler(HTTPException) -async def http_exception_handler( - request: Request, exc: HTTPException -) -> PlainTextResponse: - return PlainTextResponse( - str(exc.detail), status_code=exc.status_code, headers=exc.headers - ) +app.add_exception_handler(Exception, http_exception_handler) app.add_middleware(TraceMiddleware, service_name=app.state.service_name) From 538122b25177913d47b9a241b01683aa2ad644de Mon Sep 17 00:00:00 2001 From: maitjoshi Date: Fri, 12 Apr 2024 12:00:24 -0400 Subject: [PATCH 10/10] update handler exc --- pccommon/pccommon/middleware.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pccommon/pccommon/middleware.py b/pccommon/pccommon/middleware.py index 18ffce7a..b91a3c10 100644 --- a/pccommon/pccommon/middleware.py +++ b/pccommon/pccommon/middleware.py @@ -4,7 +4,7 @@ from functools import wraps from typing import Any, Callable -from fastapi import HTTPException, Request +from fastapi import Request from fastapi.applications import FastAPI from fastapi.dependencies.utils import ( get_body_field, @@ -16,7 +16,6 @@ from starlette.status import HTTP_504_GATEWAY_TIMEOUT from starlette.types import ASGIApp, Receive, Scope, Send -from pccommon.logging import get_custom_dimensions from pccommon.tracing import trace_request logger = logging.getLogger(__name__) @@ -24,14 +23,7 @@ async def http_exception_handler(request: Request, exc: Exception) -> Any: logger.exception("Exception when handling request", exc_info=exc) - if isinstance(exc, HTTPException): - raise - else: - logger.exception( - "Exception when handling request", - extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request), - ) - raise + raise def with_timeout(