Skip to content

Commit

Permalink
Remove BaseHTTPMiddlewares , Ensure origin host is used in STAC links (
Browse files Browse the repository at this point in the history
…#190)

* Remove BaseHTTPMiddlewares, ensure origin host is used in STAC links
* Starlette trace middleware, remove opentelemetry
* remove duplicate search tracing, use execution_context as span
---------

Co-authored-by: Gustavo Hidalgo <[email protected]>
  • Loading branch information
joshimai and ghidalgo3 authored Apr 12, 2024
1 parent 6887ffc commit f983bea
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 263 deletions.
171 changes: 89 additions & 82 deletions pccommon/pccommon/middleware.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,103 @@
import asyncio
import logging
import time
from typing import Awaitable, Callable
from functools import wraps
from typing import Any, Callable

from fastapi import HTTPException, Request, Response
from fastapi import Request
from fastapi.applications import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from fastapi.dependencies.utils import (
get_body_field,
get_dependant,
get_parameterless_sub_dependant,
)
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute, request_response
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
from starlette.types import Message
from starlette.types import ASGIApp, Receive, Scope, Send

from pccommon.logging import get_custom_dimensions
from pccommon.tracing import trace_request

logger = logging.getLogger(__name__)


async def handle_exceptions(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Invoke ``call_next`` and log unexpected failures before re-raising.

    An ``HTTPException`` is propagated untouched and is not logged. Any other
    ``Exception`` is logged with ``logger.exception`` (with the custom
    dimensions built from the exception text and the request) and is then
    re-raised unchanged.
    """
    try:
        response = await call_next(request)
    except Exception as err:
        # HTTPException is let through without logging.
        if not isinstance(err, HTTPException):
            dimensions = get_custom_dimensions({"stackTrace": f"{err}"}, request)
            logger.exception("Exception when handling request", extra=dimensions)
        raise
    return response


class RequestTracingMiddleware(BaseHTTPMiddleware):
"""Custom middleware to use opencensus request traces
Middleware implementations that access a Request object directly
will cause subsequent middleware or route handlers to hang. See
https://github.com/tiangolo/fastapi/issues/394
for more details on this implementation.
An alternative approach is to use dependencies on the APIRouter, but
the stac-fast api implementation makes that difficult without having
to override much of the app initialization.
"""

def __init__(self, app: FastAPI, service_name: str):
super().__init__(app)
async def http_exception_handler(request: Request, exc: Exception) -> Any:
    """Log ``exc`` with its traceback and re-raise it.

    Args:
        request: The request being handled; not used by this handler.
        exc: The exception to log and propagate.

    Raises:
        Exception: Always re-raises ``exc``; this function never returns.
    """
    logger.exception("Exception when handling request", exc_info=exc)
    # Re-raise the exception we were given. A bare ``raise`` only works when
    # the caller is itself inside an ``except`` block and fails with
    # RuntimeError("No active exception to reraise") otherwise.
    raise exc


def with_timeout(
    timeout_seconds: float,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that bounds coroutine functions by ``timeout_seconds``.

    Callables that are not coroutine functions are returned unchanged. A
    coroutine function is wrapped so that each call runs under
    ``asyncio.wait_for``; when the timeout expires the wrapper logs the
    elapsed time and returns a plain-text response with status
    ``HTTP_504_GATEWAY_TIMEOUT`` instead of raising.

    Args:
        timeout_seconds: Maximum time allowed for one call of the wrapped
            coroutine function.

    Returns:
        A decorator taking a callable and returning either the same callable
        (non-coroutine functions) or the timeout-enforcing wrapper.
    """

    def with_timeout_(func: Callable[..., Any]) -> Callable[..., Any]:
        if not asyncio.iscoroutinefunction(func):
            return func

        # Not every coroutine callable has ``__name__`` (e.g. functools.partial),
        # so fall back to its repr rather than failing while decorating.
        logger.debug(
            "Adding timeout to function %s", getattr(func, "__name__", repr(func))
        )

        @wraps(func)
        async def inner(*args: Any, **kwargs: Any) -> Any:
            start_time = time.monotonic()
            try:
                return await asyncio.wait_for(
                    func(*args, **kwargs), timeout=timeout_seconds
                )
            except asyncio.TimeoutError as e:
                process_time = time.monotonic() - start_time
                # No request object is available here, so only the elapsed
                # time is attached to the log record.
                logger.exception(
                    "Request timeout %s",
                    e,
                    extra={"request_time": process_time},
                )

                # NOTE(review): the contact address below looks obfuscated or
                # redacted — confirm the intended value.
                # The trailing blank line is kept so the response body stays
                # exactly what it was before.
                return PlainTextResponse(
                    "The request exceeded the maximum allowed time, please"
                    " try again. If the issue persists, please contact "
                    "[email protected]."
                    "\n\n",
                    status_code=HTTP_504_GATEWAY_TIMEOUT,
                )

        return inner

    return with_timeout_


def add_timeout(app: FastAPI, timeout_seconds: float) -> None:
    """Wrap every ``APIRoute`` endpoint of ``app`` with ``with_timeout``.

    For each ``APIRoute`` in ``app.router.routes`` the endpoint is replaced by
    its timeout-wrapped version, and the route's ``dependant``, ``body_field``
    and ``app`` attributes are rebuilt from the new endpoint. Routes of any
    other type are left untouched.
    """
    decorate = with_timeout(timeout_seconds)

    for route in app.router.routes:
        if not isinstance(route, APIRoute):
            continue

        route.endpoint = decorate(route.endpoint)

        # NOTE(review): presumably these attributes were derived from the
        # original endpoint when the route was created, hence the rebuild —
        # confirm against FastAPI's APIRoute initialisation.
        route.dependant = get_dependant(path=route.path_format, call=route.endpoint)
        # Each one is inserted at index 0, so walking in reverse keeps the
        # route-level dependencies in their declared order.
        for dependency in reversed(route.dependencies):
            sub_dependant = get_parameterless_sub_dependant(
                depends=dependency, path=route.path_format
            )
            route.dependant.dependencies.insert(0, sub_dependant)

        route.body_field = get_body_field(
            dependant=route.dependant, name=route.unique_id
        )
        route.app = request_response(route.get_route_handler())


class TraceMiddleware:
    """Middleware wrapper that holds an ASGI app and a service name."""

    def __init__(self, app: ASGIApp, service_name: str):
        """Store the wrapped ASGI ``app`` and the ``service_name`` to trace under."""
        self.app = app
        self.service_name = service_name

async def set_body(self, request: Request) -> None:
    """Read one message from ``request`` and make it replayable.

    Awaits ``request._receive()`` once, then replaces ``request._receive``
    with a coroutine function that returns that same message on every call.
    """
    first_message = await request._receive()

    async def replay() -> Message:
        return first_message

    request._receive = replay

async def dispatch(
    self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
    """Capture the request message, then delegate to ``trace_request``.

    Returns whatever ``trace_request`` returns for this service name,
    request and ``call_next``.
    """
    await self.set_body(request)
    return await trace_request(self.service_name, request, call_next)


async def timeout_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
    timeout: int,
) -> Response:
    """Run ``call_next(request)`` under a timeout.

    Args:
        request: The incoming request, also used to build log dimensions.
        call_next: Callable producing the downstream response.
        timeout: Maximum time passed to ``asyncio.wait_for``.

    Returns:
        The downstream response, or a plain-text response with status
        ``HTTP_504_GATEWAY_TIMEOUT`` when the timeout expires. The timeout
        response includes the ``ref_id`` custom dimension when one is present.
    """
    # Use a monotonic clock so the reported duration is not skewed by
    # wall-clock adjustments, and take it before the ``try`` so it is always
    # bound inside the handler.
    start_time = time.monotonic()
    try:
        return await asyncio.wait_for(call_next(request), timeout=timeout)
    except asyncio.TimeoutError:
        process_time = time.monotonic() - start_time
        log_dimensions = get_custom_dimensions({"request_time": process_time}, request)

        logger.exception(
            "Request timeout",
            extra=log_dimensions,
        )

        ref_id = log_dimensions["custom_dimensions"].get("ref_id")
        debug_msg = f"Debug information for support: {ref_id}" if ref_id else ""

        # NOTE(review): the contact address below looks obfuscated or
        # redacted — confirm the intended value.
        return PlainTextResponse(
            "The request exceeded the maximum allowed time, please try again."
            " If the issue persists, please contact [email protected]."
            f"\n\n{debug_msg}",
            status_code=HTTP_504_GATEWAY_TIMEOUT,
        )
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI entry point: trace HTTP requests, then call the wrapped app.

    ``trace_request`` is awaited only when ``scope["type"]`` is ``"http"``.
    The wrapped app is always invoked with the original ``scope``,
    ``receive`` and ``send``.
    """
    is_http = scope["type"] == "http"
    if is_http:
        await trace_request(self.service_name, Request(scope, receive))

    await self.app(scope, receive, send)
106 changes: 58 additions & 48 deletions pccommon/pccommon/tracing.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import json
import logging
import re
from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast
from typing import List, Optional, Tuple, Union, cast

from fastapi import Request, Response
import fastapi
from fastapi import Request
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import execution_context
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
from starlette.datastructures import QueryParams

from pccommon.config import get_apis_config
from pccommon.constants import (
HTTP_METHOD,
HTTP_PATH,
HTTP_STATUS_CODE,
HTTP_URL,
X_AZURE_REF,
X_REQUEST_ENTITY,
Expand Down Expand Up @@ -41,8 +43,7 @@
async def trace_request(
service_name: str,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
) -> None:
"""Construct a request trace with custom dimensions"""
request_path = request_to_path(request).strip("/")

Expand Down Expand Up @@ -99,17 +100,6 @@ async def trace_request(
attribute_key="item", attribute_value=item_id
)

# Call next middleware
response = await call_next(request)

# Include response dimensions in the trace
tracer.add_attribute_to_current_span(
attribute_key=HTTP_STATUS_CODE, attribute_value=response.status_code
)
return response
else:
return await call_next(request)


collection_id_re = re.compile(
r".*/collections/?(?P<collection_id>[a-zA-Z0-9\-\%]+)?(/items/(?P<item_id>.*))?.*" # noqa
Expand All @@ -122,7 +112,6 @@ async def _collection_item_from_request(
) -> Tuple[Optional[str], Optional[str]]:
"""Attempt to get collection and item ids from the request path or querystring."""
url = request.url
path = url.path.strip("/")
try:
collection_id_match = collection_id_re.match(f"{url}")
if collection_id_match:
Expand All @@ -131,8 +120,6 @@ async def _collection_item_from_request(
)
item_id = cast(Optional[str], collection_id_match.group("item_id"))
return (collection_id, item_id)
elif path.endswith("/search") or path.endswith("/register"):
return await _parse_collection_from_search(request)
else:
collection_id = request.query_params.get("collection")
# Some endpoints, like preview/, take an `items` parameter, but
Expand All @@ -159,35 +146,6 @@ def _should_trace_request(request: Request) -> bool:
)


async def _parse_collection_from_search(
    request: Request,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract the collection and item ids from a search request.

    GET requests are answered from the `collections` and `ids` querystring
    parameters. POST requests are answered from the JSON body: a body with a
    `collections` key is handed to `_parse_queryjson`, otherwise a body with a
    `filter` key has that filter handed to `_parse_cqljson`. A body that is
    not valid JSON is logged and ignored. Everything else yields
    `(None, None)`.
    """
    verb = request.method.lower()

    if verb == "get":
        params = request.query_params
        return (params.get("collections"), params.get("ids"))

    if verb == "post":
        try:
            payload = await request.json()
            if "collections" in payload:
                return _parse_queryjson(payload)
            if "filter" in payload:
                return _parse_cqljson(payload["filter"])
        except json.JSONDecodeError:
            logger.warning(
                "Unable to parse search body as JSON. Ignoring collection parameter."
            )

    return (None, None)


def _parse_cqljson(cql: dict) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a CQL-JSON filter.
Expand Down Expand Up @@ -232,6 +190,8 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
provided property name, if found. Typical usage will be to provide
`collection` and `id`.
"""
if cql is None:
return None
for _, v in cql.items():
if isinstance(v, dict):
result = _iter_cql(v, property_name)
Expand All @@ -249,3 +209,53 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
return result
# No collection was found
return None


def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None:
    """
    Try to add the Collection ID and Item ID from a search to the current span.

    `search_json` is decoded as JSON and passed, with the request method and
    query parameters, to `parse_collection_from_search`. Any ids found are
    added as the `collection` and `item` attributes of the current span, or of
    `request.state.parent_span` when there is no current span. A warning is
    logged when neither span is available.

    A body that is not valid JSON is logged and treated as empty, so this
    helper does not raise on a malformed search body.
    """
    try:
        body = json.loads(search_json)
    except json.JSONDecodeError as e:
        logger.warning(
            "Unable to parse search body as JSON. Ignoring collection parameter. %s",
            e,
        )
        # Carry on with an empty body: GET searches are answered from the
        # querystring and do not need it.
        body = {}

    collection_id, item_id = parse_collection_from_search(
        body, request.method, request.query_params
    )
    parent_span = getattr(request.state, "parent_span", None)

    current_span = execution_context.get_current_span() or parent_span

    if current_span:
        if collection_id is not None:
            current_span.add_attribute("collection", collection_id)
        if item_id is not None:
            current_span.add_attribute("item", item_id)
    else:
        logger.warning("No active or parent span available for adding attributes.")


def parse_collection_from_search(
    body: dict,
    method: str,
    query_params: QueryParams,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Parse the collection id and item id from a search request.

    The search endpoint is a bit of a special case. If it's a GET, the
    collection and item ids are read from the `collections` and `ids`
    querystring parameters. If it's a POST, a body whose `collections` value
    is not None is handed to `_parse_queryjson`; otherwise a body with a
    `filter` key has that filter handed to `_parse_cqljson`.

    Args:
        body: The already-decoded search body.
        method: The HTTP method, compared case-insensitively.
        query_params: The request's query parameters.

    Returns:
        A `(collection_id, item_id)` tuple; `(None, None)` when nothing
        applies.
    """
    verb = method.lower()

    if verb == "get":
        return (query_params.get("collections"), query_params.get("ids"))

    if verb == "post":
        try:
            if body.get("collections") is not None:
                return _parse_queryjson(body)
            if "filter" in body:
                return _parse_cqljson(body["filter"])
        except json.JSONDecodeError as e:
            # The two message fragments used to be joined without a space,
            # producing "collectionparameter."
            logger.warning(
                "Unable to parse search body as JSON. Ignoring collection "
                "parameter. %s",
                e,
            )

    return (None, None)
Loading

0 comments on commit f983bea

Please sign in to comment.