Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove BaseHTTPMiddlewares , Ensure origin host is used in STAC links #190

Merged
merged 15 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 89 additions & 82 deletions pccommon/pccommon/middleware.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,103 @@
import asyncio
import logging
import time
from typing import Awaitable, Callable
from functools import wraps
from typing import Any, Callable

from fastapi import HTTPException, Request, Response
from fastapi import Request
from fastapi.applications import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from fastapi.dependencies.utils import (
get_body_field,
get_dependant,
get_parameterless_sub_dependant,
)
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute, request_response
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
from starlette.types import Message
from starlette.types import ASGIApp, Receive, Scope, Send

from pccommon.logging import get_custom_dimensions
from pccommon.tracing import trace_request

logger = logging.getLogger(__name__)


async def handle_exceptions(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
try:
return await call_next(request)
except HTTPException:
raise
except Exception as e:
logger.exception(
"Exception when handling request",
extra=get_custom_dimensions({"stackTrace": f"{e}"}, request),
)
raise


class RequestTracingMiddleware(BaseHTTPMiddleware):
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
"""Custom middleware to use opencensus request traces

Middleware implementations that access a Request object directly
will cause subsequent middleware or route handlers to hang. See

https://github.com/tiangolo/fastapi/issues/394

for more details on this implementation.

An alternative approach is to use dependencies on the APIRouter, but
the stac-fast api implementation makes that difficult without having
to override much of the app initialization.
"""

def __init__(self, app: FastAPI, service_name: str):
super().__init__(app)
async def http_exception_handler(request: Request, exc: Exception) -> Any:
logger.exception("Exception when handling request", exc_info=exc)
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
raise


def with_timeout(
timeout_seconds: float,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def with_timeout_(func: Callable[..., Any]) -> Callable[..., Any]:
if asyncio.iscoroutinefunction(func):
logger.debug("Adding timeout to function %s", func.__name__)

@wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any:
start_time = time.monotonic()
try:
return await asyncio.wait_for(
func(*args, **kwargs), timeout=timeout_seconds
)
except asyncio.TimeoutError as e:
process_time = time.monotonic() - start_time
# don't have a request object here to get custom dimensions.
log_dimensions = {
"request_time": process_time,
}
logger.exception(
f"Request timeout {e}",
extra=log_dimensions,
)

ref_id = log_dimensions.get("ref_id")
debug_msg = (
f" Debug information for support: {ref_id}" if ref_id else ""
)

return PlainTextResponse(
f"The request exceeded the maximum allowed time, please"
" try again. If the issue persists, please contact "
"[email protected]."
f"\n\n{debug_msg}",
status_code=HTTP_504_GATEWAY_TIMEOUT,
)

return inner
else:
return func

return with_timeout_


def add_timeout(app: FastAPI, timeout_seconds: float) -> None:
for route in app.router.routes:
if isinstance(route, APIRoute):
new_endpoint = with_timeout(timeout_seconds)(route.endpoint)
route.endpoint = new_endpoint
route.dependant = get_dependant(path=route.path_format, call=route.endpoint)
for depends in route.dependencies[::-1]:
route.dependant.dependencies.insert(
0,
get_parameterless_sub_dependant(
depends=depends, path=route.path_format
),
)
route.body_field = get_body_field(
dependant=route.dependant, name=route.unique_id
)
route.app = request_response(route.get_route_handler())


class TraceMiddleware:
def __init__(self, app: ASGIApp, service_name: str):
self.app = app
self.service_name = service_name

async def set_body(self, request: Request) -> None:
receive_ = await request._receive()

async def receive() -> Message:
return receive_

request._receive = receive

async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
await self.set_body(request)
response = await trace_request(self.service_name, request, call_next)
return response


async def timeout_middleware(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
timeout: int,
) -> Response:
try:
start_time = time.time()
return await asyncio.wait_for(call_next(request), timeout=timeout)

except asyncio.TimeoutError:
process_time = time.time() - start_time
log_dimensions = get_custom_dimensions({"request_time": process_time}, request)

logger.exception(
"Request timeout",
extra=log_dimensions,
)

ref_id = log_dimensions["custom_dimensions"].get("ref_id")
debug_msg = f"Debug information for support: {ref_id}" if ref_id else ""

return PlainTextResponse(
f"The request exceeded the maximum allowed time, please try again."
" If the issue persists, please contact [email protected]."
f"\n\n{debug_msg}",
status_code=HTTP_504_GATEWAY_TIMEOUT,
)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
request: Request = Request(scope, receive)
await trace_request(self.service_name, request)

await self.app(scope, receive, send)
106 changes: 58 additions & 48 deletions pccommon/pccommon/tracing.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import json
import logging
import re
from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast
from typing import List, Optional, Tuple, Union, cast

from fastapi import Request, Response
import fastapi
from fastapi import Request
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import execution_context
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
from starlette.datastructures import QueryParams

from pccommon.config import get_apis_config
from pccommon.constants import (
HTTP_METHOD,
HTTP_PATH,
HTTP_STATUS_CODE,
HTTP_URL,
X_AZURE_REF,
X_REQUEST_ENTITY,
Expand Down Expand Up @@ -41,8 +43,7 @@
async def trace_request(
service_name: str,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
) -> None:
"""Construct a request trace with custom dimensions"""
request_path = request_to_path(request).strip("/")

Expand Down Expand Up @@ -99,17 +100,6 @@ async def trace_request(
attribute_key="item", attribute_value=item_id
)

# Call next middleware
response = await call_next(request)

# Include response dimensions in the trace
tracer.add_attribute_to_current_span(
attribute_key=HTTP_STATUS_CODE, attribute_value=response.status_code
)
return response
else:
return await call_next(request)


collection_id_re = re.compile(
r".*/collections/?(?P<collection_id>[a-zA-Z0-9\-\%]+)?(/items/(?P<item_id>.*))?.*" # noqa
Expand All @@ -122,7 +112,6 @@ async def _collection_item_from_request(
) -> Tuple[Optional[str], Optional[str]]:
"""Attempt to get collection and item ids from the request path or querystring."""
url = request.url
path = url.path.strip("/")
try:
collection_id_match = collection_id_re.match(f"{url}")
if collection_id_match:
Expand All @@ -131,8 +120,6 @@ async def _collection_item_from_request(
)
item_id = cast(Optional[str], collection_id_match.group("item_id"))
return (collection_id, item_id)
elif path.endswith("/search") or path.endswith("/register"):
return await _parse_collection_from_search(request)
else:
collection_id = request.query_params.get("collection")
# Some endpoints, like preview/, take an `items` parameter, but
Expand All @@ -159,35 +146,6 @@ def _should_trace_request(request: Request) -> bool:
)


async def _parse_collection_from_search(
request: Request,
) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a search request.

The search endpoint is a bit of a special case. If it's a GET, the collection
and item ids are in the querystring. If it's a POST, the collection and item may
be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body.
"""

if request.method.lower() == "get":
collection_id = request.query_params.get("collections")
item_id = request.query_params.get("ids")
return (collection_id, item_id)
elif request.method.lower() == "post":
try:
body = await request.json()
if "collections" in body:
return _parse_queryjson(body)
elif "filter" in body:
return _parse_cqljson(body["filter"])
except json.JSONDecodeError:
logger.warning(
"Unable to parse search body as JSON. Ignoring collection parameter."
)
return (None, None)


def _parse_cqljson(cql: dict) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a CQL-JSON filter.
Expand Down Expand Up @@ -232,6 +190,8 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
provided property name, if found. Typical usage will be to provide
`collection` and `id`.
"""
if cql is None:
return None
for _, v in cql.items():
if isinstance(v, dict):
result = _iter_cql(v, property_name)
Expand All @@ -249,3 +209,53 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
return result
# No collection was found
return None


def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None:
"""
Try to add the Collection ID and Item ID from a search to the current span.
"""
collection_id, item_id = parse_collection_from_search(
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
json.loads(search_json), request.method, request.query_params
)
parent_span = getattr(request.state, "parent_span", None)

current_span = execution_context.get_current_span() or parent_span

if current_span:
if collection_id is not None:
current_span.add_attribute("collection", collection_id)
if item_id is not None:
current_span.add_attribute("item", item_id)
else:
logger.warning("No active or parent span available for adding attributes.")


def parse_collection_from_search(
body: dict,
method: str,
query_params: QueryParams,
) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a search request.

The search endpoint is a bit of a special case. If it's a GET, the collection
and item ids are in the querystring. If it's a POST, the collection and item may
be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body.
"""
if method.lower() == "get":
collection_id = query_params.get("collections")
item_id = query_params.get("ids")
return (collection_id, item_id)
elif method.lower() == "post":
try:
if body.get("collections") is not None:
return _parse_queryjson(body)
elif "filter" in body:
return _parse_cqljson(body["filter"])
except json.JSONDecodeError as e:
logger.warning(
"Unable to parse search body as JSON. Ignoring collection"
f"parameter. {e}"
)
return (None, None)
Loading
Loading