Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove BaseHTTPMiddlewares, ensure origin host is used in STAC links #190

Merged
merged 15 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 80 additions & 77 deletions pccommon/pccommon/middleware.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,99 @@
import asyncio
import logging
import time
from typing import Awaitable, Callable
from functools import wraps
from typing import Any, Callable

from fastapi import HTTPException, Request, Response
from fastapi import HTTPException, Request
from fastapi.applications import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from fastapi.dependencies.utils import (
get_body_field,
get_dependant,
get_parameterless_sub_dependant,
)
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute, request_response
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
from starlette.types import Message

from pccommon.logging import get_custom_dimensions
from pccommon.tracing import trace_request

logger = logging.getLogger(__name__)


async def handle_exceptions(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
try:
return await call_next(request)
except HTTPException:
async def http_exception_handler(request: Request, exc: Exception) -> Any:
# Log the exception with additional request info if needed
logger.exception("Exception when handling request", exc_info=exc)
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
# Return a custom response for HTTPException
if isinstance(exc, HTTPException):
raise
except Exception as e:
# Handle other exceptions, possibly with a generic response
else:
logger.exception(
"Exception when handling request",
extra=get_custom_dimensions({"stackTrace": f"{e}"}, request),
extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request),
)
raise


class RequestTracingMiddleware(BaseHTTPMiddleware):
    """Custom middleware to use opencensus request traces

    Middleware implementations that access a Request object directly
    will cause subsequent middleware or route handlers to hang. See

    https://github.com/tiangolo/fastapi/issues/394

    for more details on this implementation.

    An alternative approach is to use dependencies on the APIRouter, but
    the stac-fast api implementation makes that difficult without having
    to override much of the app initialization.
    """

    def __init__(self, app: FastAPI, service_name: str):
        """Store the service name that is handed to ``trace_request``.

        Args:
            app: Application passed through to the base middleware.
            service_name: Name forwarded to ``trace_request`` on each dispatch.
        """
        super().__init__(app)
        self.service_name = service_name

    async def set_body(self, request: Request) -> None:
        """Read one message from the request and make it replayable.

        Awaits a single message from ``request._receive`` and then replaces
        ``request._receive`` with a coroutine that returns that same message
        on every call.
        """
        # NOTE(review): only one message is read here; confirm that request
        # bodies delivered in several messages are handled as intended.
        receive_ = await request._receive()

        async def receive() -> Message:
            return receive_

        request._receive = receive

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Buffer the request message, then run the request under a trace.

        Args:
            request: Incoming request.
            call_next: Callable that produces the downstream response.

        Returns:
            The response returned by ``trace_request``.
        """
        await self.set_body(request)
        response = await trace_request(self.service_name, request, call_next)
        return response


async def timeout_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
    timeout: int,
) -> Response:
    """Run the downstream handler, answering with a 504 if it takes too long.

    Args:
        request: Incoming request, also used to build the log dimensions.
        call_next: Callable that produces the downstream response.
        timeout: Maximum number of seconds to wait for ``call_next``.

    Returns:
        The downstream response, or a plain-text 504 response when
        ``asyncio.wait_for`` raises ``asyncio.TimeoutError``.
    """
    # Monotonic clock: the value is only used to compute an elapsed duration,
    # so it must not jump if the wall clock is adjusted mid-request.
    start_time = time.monotonic()
    try:
        return await asyncio.wait_for(call_next(request), timeout=timeout)
    except asyncio.TimeoutError:
        process_time = time.monotonic() - start_time
        log_dimensions = get_custom_dimensions({"request_time": process_time}, request)

        logger.exception(
            "Request timeout",
            extra=log_dimensions,
        )

        # Surface the reference id, when one is present, so support can
        # correlate the user's report with the logged timeout.
        ref_id = log_dimensions["custom_dimensions"].get("ref_id")
        debug_msg = f"Debug information for support: {ref_id}" if ref_id else ""

        return PlainTextResponse(
            "The request exceeded the maximum allowed time, please try again."
            " If the issue persists, please contact [email protected]."
            f"\n\n{debug_msg}",
            status_code=HTTP_504_GATEWAY_TIMEOUT,
        )
def with_timeout(
    timeout_seconds: float,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that bounds the run time of coroutine functions.

    A coroutine function is wrapped so that each call is awaited through
    ``asyncio.wait_for``. When the deadline passes, the timeout is logged and
    a plain-text 504 response is returned instead of an exception being
    raised. Any callable that is not a coroutine function is returned
    unchanged.

    Args:
        timeout_seconds: Maximum number of seconds a wrapped call may run.

    Returns:
        A decorator that takes a callable and returns either the
        timeout-guarded wrapper or the original callable.
    """

    def with_timeout_(func: Callable[..., Any]) -> Callable[..., Any]:
        # Only coroutine functions can be awaited through wait_for.
        if not asyncio.iscoroutinefunction(func):
            return func

        logger.debug("Adding timeout to function %s", func.__name__)

        @wraps(func)
        async def inner(*args: Any, **kwargs: Any) -> Any:
            start_time = time.monotonic()
            try:
                return await asyncio.wait_for(
                    func(*args, **kwargs), timeout=timeout_seconds
                )
            except asyncio.TimeoutError as e:
                process_time = time.monotonic() - start_time
                # don't have a request object here to get custom dimensions,
                # so there is no ref_id to echo back in the response either.
                logger.exception(
                    "Request timeout %s",
                    e,
                    extra={"request_time": process_time},
                )

                return PlainTextResponse(
                    "The request exceeded the maximum allowed time, please"
                    " try again. If the issue persists, please contact "
                    "[email protected]."
                    "\n\n",
                    status_code=HTTP_504_GATEWAY_TIMEOUT,
                )

        return inner

    return with_timeout_


def add_timeout(app: FastAPI, timeout_seconds: float) -> None:
    """Apply ``with_timeout`` to the endpoint of every ``APIRoute`` on the app.

    Only routes already present in ``app.router.routes`` when this is called
    are affected. After the endpoint is swapped, the route's dependant, body
    field and ASGI app are rebuilt from it so the wrapped endpoint is the one
    the route actually invokes.

    Args:
        app: Application whose routes are updated in place.
        timeout_seconds: Maximum number of seconds an endpoint may run.
    """
    # The decorator does not depend on the route, so build it once.
    apply_timeout = with_timeout(timeout_seconds)

    for route in app.router.routes:
        if not isinstance(route, APIRoute):
            continue

        route.endpoint = apply_timeout(route.endpoint)

        # Re-derive everything the route computed from its endpoint.
        route.dependant = get_dependant(path=route.path_format, call=route.endpoint)
        # Insert in reverse so the route-level dependencies end up first and
        # in their declared order.
        for depends in route.dependencies[::-1]:
            route.dependant.dependencies.insert(
                0,
                get_parameterless_sub_dependant(
                    depends=depends, path=route.path_format
                ),
            )
        route.body_field = get_body_field(
            dependant=route.dependant, name=route.unique_id
        )
        route.app = request_response(route.get_route_handler())
54 changes: 54 additions & 0 deletions pccommon/pccommon/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import re
from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast

import fastapi
from fastapi import Request, Response
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
from opentelemetry import trace
from starlette.datastructures import QueryParams

from pccommon.config import get_apis_config
from pccommon.constants import (
Expand All @@ -25,6 +28,11 @@
logger = logging.getLogger(__name__)


COLLECTION = "spatio.collection"
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
COLLECTIONS = "spatio.collections"
ITEM = "spatio.item"
ITEMS = "spatio.items"

exporter = (
AzureExporter(
connection_string=(
Expand Down Expand Up @@ -249,3 +257,49 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
return result
# No collection was found
return None


def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None:
    """
    Try to add the Collection ID and Item ID from a search to the current span.

    This is best-effort: if ``search_json`` cannot be decoded as JSON a
    warning is logged and no attributes are set, rather than the error
    propagating to the caller.

    Args:
        search_json: The search request serialized as a JSON string.
        request: The request, used for its method and query parameters.
    """
    try:
        search_body = json.loads(search_json)
    except json.JSONDecodeError as e:
        logger.warning(
            "Unable to parse search body as JSON. Skipping span attributes. %s",
            e,
        )
        return

    collection_id, item_id = parse_collection_from_search(
        search_body, request.method, request.query_params
    )
    span = trace.get_current_span()

    if collection_id is not None:
        span.set_attribute(COLLECTIONS, collection_id)

    if item_id is not None:
        span.set_attribute(ITEMS, item_id)


def parse_collection_from_search(
    body: dict,
    method: str,
    query_params: QueryParams,
) -> Tuple[Optional[str], Optional[str]]:
    """
    Parse the collection id from a search request.

    The search endpoint is a bit of a special case. If it's a GET, the collection
    and item ids are in the querystring. If it's a POST, the collection and item may
    be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body.

    Args:
        body: The decoded search body; only consulted for POST requests.
        method: The HTTP method, compared case-insensitively.
        query_params: The request query parameters; only consulted for GET.

    Returns:
        A ``(collection_id, item_id)`` tuple. Both are ``None`` when the method
        is neither GET nor POST, or when a POST body has neither a
        ``collections`` nor a ``filter`` key.
    """
    http_method = method.lower()

    if http_method == "get":
        collection_id = query_params.get("collections")
        item_id = query_params.get("ids")
        return (collection_id, item_id)

    if http_method == "post":
        try:
            if "collections" in body:
                return _parse_queryjson(body)
            elif "filter" in body:
                return _parse_cqljson(body["filter"])
        except json.JSONDecodeError as e:
            # NOTE(review): body is already a dict here, so this only fires if
            # the parse helpers decode JSON themselves; confirm.
            logger.warning(
                "Unable to parse search body as JSON. Ignoring collection "
                f"parameter. {e}"
            )

    return (None, None)
2 changes: 2 additions & 0 deletions pccommon/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"html-sanitizer==2.4",
# Soon available as lxml[html_clean]
"lxml_html_clean==0.1.0",
"opentelemetry-api==1.21.0",
"opentelemetry-sdk==1.21.0",
]

extra_reqs = {
Expand Down
90 changes: 16 additions & 74 deletions pccommon/tests/test_timeouts.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import asyncio
import random
from typing import Awaitable, Callable
from typing import Any

import pytest
from fastapi import FastAPI, Request, Response
from fastapi.responses import PlainTextResponse
from fastapi import FastAPI

# from fastapi.responses import PlainTextResponse
from httpx import AsyncClient
from starlette.status import HTTP_200_OK, HTTP_504_GATEWAY_TIMEOUT
from starlette.status import HTTP_504_GATEWAY_TIMEOUT

from pccommon.middleware import timeout_middleware
from pccommon.middleware import add_timeout

TIMEOUT_SECONDS = 2
BASE_URL = "http://test"
Expand All @@ -20,80 +20,22 @@
app.state.service_name = "test"


@app.middleware("http")
async def _timeout_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a timeout to all requests."""
return await timeout_middleware(request, call_next, timeout=TIMEOUT_SECONDS)


# Test endpoint to sleep for a configurable amount of time, which may exceed the
# timeout middleware setting
@app.get("/sleep", response_class=PlainTextResponse)
async def route_for_test(t: int) -> str:
await asyncio.sleep(t)
return "Done"


# Test endpoint to sleep and confirm that the task is cancelled after the timeout
@app.get("/cancel", response_class=PlainTextResponse)
async def route_for_cancel_test(t: int) -> str:
for i in range(t):
await asyncio.sleep(1)
if i > TIMEOUT_SECONDS:
raise Exception("Task should have been cancelled")

return "Done"


# Test middleware
# ===============


async def success_response(client: AsyncClient, timeout: int) -> None:
print("making request")
response = await client.get("/sleep", params={"t": timeout})
assert response.status_code == HTTP_200_OK
assert response.text == "Done"
@app.get("/asleep")
async def asleep() -> Any:
await asyncio.sleep(1)
return {}


async def timeout_response(client: AsyncClient, timeout: int) -> None:
response = await client.get("/sleep", params={"t": timeout})
assert response.status_code == HTTP_504_GATEWAY_TIMEOUT


@pytest.mark.asyncio
async def test_timeout() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
await timeout_response(client, 10)

# Run this after registering the routes

@pytest.mark.asyncio
async def test_no_timeout() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
await success_response(client, 1)
add_timeout(app, timeout_seconds=0.001)


@pytest.mark.asyncio
async def test_multiple_requests() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
timeout_tasks = []
for _ in range(100):
t = TIMEOUT_SECONDS + random.randint(1, 10)
timeout_tasks.append(asyncio.ensure_future(timeout_response(client, t)))

await asyncio.gather(*timeout_tasks)

success_tasks = []
for _ in range(100):
t = TIMEOUT_SECONDS - 1
success_tasks.append(asyncio.ensure_future(success_response(client, t)))
async def test_add_timeout() -> None:

await asyncio.gather(*success_tasks)
client = AsyncClient(app=app, base_url=BASE_URL)

response = await client.get("/asleep")

@pytest.mark.asyncio
async def test_request_cancelled() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
await client.get("/cancel", params={"t": 10})
assert response.status_code == HTTP_504_GATEWAY_TIMEOUT
3 changes: 3 additions & 0 deletions pcstac/pcstac/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pccommon.constants import DEFAULT_COLLECTION_REGION
from pccommon.logging import get_custom_dimensions
from pccommon.redis import back_pressure, cached_result, rate_limit
from pccommon.tracing import add_stac_attributes_from_search
from pcstac.config import API_DESCRIPTION, API_LANDING_PAGE_ID, API_TITLE, get_settings
from pcstac.contants import (
CACHE_KEY_COLLECTION,
Expand Down Expand Up @@ -227,6 +228,8 @@ async def _fetch() -> ItemCollection:
return item_collection

search_json = search_request.json()
add_stac_attributes_from_search(search_json, request)

logger.info(
"STAC: Item search body",
extra=get_custom_dimensions({"search_body": search_json}, request),
Expand Down
Loading
Loading