Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove BaseHTTPMiddlewares, ensure origin host is used in STAC links #190

Merged
merged 15 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 80 additions & 77 deletions pccommon/pccommon/middleware.py
Original file line number Diff line number Diff line change
@@ -1,96 +1,99 @@
import asyncio
import logging
import time
from typing import Awaitable, Callable
from functools import wraps
from typing import Any, Callable

from fastapi import HTTPException, Request, Response
from fastapi import HTTPException, Request
from fastapi.applications import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from fastapi.dependencies.utils import (
get_body_field,
get_dependant,
get_parameterless_sub_dependant,
)
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute, request_response
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
from starlette.types import Message

from pccommon.logging import get_custom_dimensions
from pccommon.tracing import trace_request

logger = logging.getLogger(__name__)


async def handle_exceptions(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
try:
return await call_next(request)
except HTTPException:
async def http_exception_handler(request: Request, exc: Exception) -> Any:
# Log the exception with additional request info if needed
logger.exception("Exception when handling request", exc_info=exc)
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
# Return a custom response for HTTPException
if isinstance(exc, HTTPException):
raise
except Exception as e:
# Handle other exceptions, possibly with a generic response
else:
logger.exception(
"Exception when handling request",
extra=get_custom_dimensions({"stackTrace": f"{e}"}, request),
extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request),
)
raise


class RequestTracingMiddleware(BaseHTTPMiddleware):
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
"""Custom middleware to use opencensus request traces

Middleware implementations that access a Request object directly
will cause subsequent middleware or route handlers to hang. See

https://github.com/tiangolo/fastapi/issues/394

for more details on this implementation.

An alternative approach is to use dependencies on the APIRouter, but
the stac-fast api implementation makes that difficult without having
to override much of the app initialization.
"""

def __init__(self, app: FastAPI, service_name: str):
super().__init__(app)
self.service_name = service_name

async def set_body(self, request: Request) -> None:
receive_ = await request._receive()

async def receive() -> Message:
return receive_

request._receive = receive

async def dispatch(
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
await self.set_body(request)
response = await trace_request(self.service_name, request, call_next)
return response


async def timeout_middleware(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
    timeout: int,
) -> Response:
    """Run the downstream handler, giving up after ``timeout`` seconds.

    Args:
        request: The incoming request, forwarded to ``call_next`` and used to
            build the log dimensions when a timeout occurs.
        call_next: Awaitable callable that produces the response.
        timeout: Maximum number of seconds to wait for ``call_next``.

    Returns:
        The response from ``call_next``, or a plain-text 504 response when
        the wait exceeds ``timeout``.
    """
    # Monotonic clock: the elapsed time must not be skewed by wall-clock
    # adjustments that happen while the request is in flight.
    start_time = time.monotonic()
    try:
        return await asyncio.wait_for(call_next(request), timeout=timeout)
    except asyncio.TimeoutError:
        process_time = time.monotonic() - start_time
        log_dimensions = get_custom_dimensions({"request_time": process_time}, request)

        logger.exception(
            "Request timeout",
            extra=log_dimensions,
        )

        # Surface the reference id (when one was recorded) so a user can
        # quote it to support.
        ref_id = log_dimensions["custom_dimensions"].get("ref_id")
        debug_msg = f"Debug information for support: {ref_id}" if ref_id else ""

        return PlainTextResponse(
            "The request exceeded the maximum allowed time, please try again."
            " If the issue persists, please contact [email protected]."
            f"\n\n{debug_msg}",
            status_code=HTTP_504_GATEWAY_TIMEOUT,
        )
def with_timeout(
    timeout_seconds: float,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a decorator that bounds the run time of a coroutine function.

    Args:
        timeout_seconds: Maximum number of seconds the wrapped coroutine may
            run before it is abandoned via ``asyncio.wait_for``.

    Returns:
        A decorator. Applied to a coroutine function it returns a wrapper
        that yields the function's own result, or a plain-text 504 response
        when the timeout elapses. Applied to anything else it returns the
        callable unchanged.
    """

    def with_timeout_(func: Callable[..., Any]) -> Callable[..., Any]:
        # Only coroutine functions can be awaited under asyncio.wait_for;
        # every other callable is handed back untouched.
        if not asyncio.iscoroutinefunction(func):
            return func

        logger.debug("Adding timeout to function %s", func.__name__)

        @wraps(func)
        async def inner(*args: Any, **kwargs: Any) -> Any:
            start_time = time.monotonic()
            try:
                return await asyncio.wait_for(
                    func(*args, **kwargs), timeout=timeout_seconds
                )
            except asyncio.TimeoutError as e:
                process_time = time.monotonic() - start_time
                # There is no request object here, so only the elapsed time
                # is logged. Without a request there is no ref_id either, so
                # the response carries no support/debug suffix.
                # NOTE(review): the middleware version of this logic logged
                # via get_custom_dimensions(), which appears to nest values
                # under "custom_dimensions"; confirm the log exporter picks
                # up this flat ``extra`` dict.
                log_dimensions = {
                    "request_time": process_time,
                }
                logger.exception(
                    "Request timeout %s",
                    e,
                    extra=log_dimensions,
                )

                return PlainTextResponse(
                    "The request exceeded the maximum allowed time, please"
                    " try again. If the issue persists, please contact "
                    "[email protected]."
                    "\n\n",
                    status_code=HTTP_504_GATEWAY_TIMEOUT,
                )

        return inner

    return with_timeout_


def add_timeout(app: FastAPI, timeout_seconds: float) -> None:
    """Wrap the endpoint of every ``APIRoute`` on ``app`` with a timeout.

    Call this after all routes have been registered; only routes present in
    ``app.router.routes`` at call time are touched.

    Args:
        app: Application whose routes are updated in place.
        timeout_seconds: Timeout passed to ``with_timeout`` for each endpoint.
    """
    for route in app.router.routes:
        if not isinstance(route, APIRoute):
            continue

        new_endpoint = with_timeout(timeout_seconds)(route.endpoint)
        if new_endpoint is route.endpoint:
            # with_timeout returned the endpoint unchanged (it is not a
            # coroutine function), so the route's existing wiring is still
            # valid and does not need to be rebuilt.
            continue

        route.endpoint = new_endpoint

        # Rebuild everything the route derived from its endpoint so requests
        # are dispatched to the wrapped callable.
        # NOTE(review): presumably mirrors the setup FastAPI performs when a
        # route is first created; confirm against the installed FastAPI
        # version, since these helpers are internal.
        route.dependant = get_dependant(path=route.path_format, call=route.endpoint)
        # Walk the route-level dependencies in reverse while inserting at
        # index 0, which keeps them in their declared order ahead of the
        # endpoint's own dependencies.
        for depends in route.dependencies[::-1]:
            route.dependant.dependencies.insert(
                0,
                get_parameterless_sub_dependant(
                    depends=depends, path=route.path_format
                ),
            )
        route.body_field = get_body_field(
            dependant=route.dependant, name=route.unique_id
        )
        route.app = request_response(route.get_route_handler())
90 changes: 16 additions & 74 deletions pccommon/tests/test_timeouts.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import asyncio
import random
from typing import Awaitable, Callable
from typing import Any

import pytest
from fastapi import FastAPI, Request, Response
from fastapi.responses import PlainTextResponse
from fastapi import FastAPI

# from fastapi.responses import PlainTextResponse
from httpx import AsyncClient
from starlette.status import HTTP_200_OK, HTTP_504_GATEWAY_TIMEOUT
from starlette.status import HTTP_504_GATEWAY_TIMEOUT

from pccommon.middleware import timeout_middleware
from pccommon.middleware import add_timeout

TIMEOUT_SECONDS = 2
BASE_URL = "http://test"
Expand All @@ -20,80 +20,22 @@
app.state.service_name = "test"


@app.middleware("http")
async def _timeout_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a timeout to all requests."""
return await timeout_middleware(request, call_next, timeout=TIMEOUT_SECONDS)


# Test endpoint to sleep for a configurable amount of time, which may exceed the
# timeout middleware setting
@app.get("/sleep", response_class=PlainTextResponse)
async def route_for_test(t: int) -> str:
await asyncio.sleep(t)
return "Done"


# Test endpoint to sleep and confirm that the task is cancelled after the timeout
@app.get("/cancel", response_class=PlainTextResponse)
async def route_for_cancel_test(t: int) -> str:
for i in range(t):
await asyncio.sleep(1)
if i > TIMEOUT_SECONDS:
raise Exception("Task should have been cancelled")

return "Done"


# Test middleware
# ===============


async def success_response(client: AsyncClient, timeout: int) -> None:
print("making request")
response = await client.get("/sleep", params={"t": timeout})
assert response.status_code == HTTP_200_OK
assert response.text == "Done"
@app.get("/asleep")
async def asleep() -> Any:
await asyncio.sleep(1)
return {}


async def timeout_response(client: AsyncClient, timeout: int) -> None:
response = await client.get("/sleep", params={"t": timeout})
assert response.status_code == HTTP_504_GATEWAY_TIMEOUT


@pytest.mark.asyncio
async def test_timeout() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
await timeout_response(client, 10)

# Run this after registering the routes

@pytest.mark.asyncio
async def test_no_timeout() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
await success_response(client, 1)
add_timeout(app, timeout_seconds=0.001)


@pytest.mark.asyncio
async def test_multiple_requests() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
timeout_tasks = []
for _ in range(100):
t = TIMEOUT_SECONDS + random.randint(1, 10)
timeout_tasks.append(asyncio.ensure_future(timeout_response(client, t)))

await asyncio.gather(*timeout_tasks)

success_tasks = []
for _ in range(100):
t = TIMEOUT_SECONDS - 1
success_tasks.append(asyncio.ensure_future(success_response(client, t)))
async def test_add_timeout() -> None:

await asyncio.gather(*success_tasks)
client = AsyncClient(app=app, base_url=BASE_URL)

response = await client.get("/asleep")

@pytest.mark.asyncio
async def test_request_cancelled() -> None:
async with AsyncClient(app=app, base_url=BASE_URL) as client:
await client.get("/cancel", params={"t": 10})
assert response.status_code == HTTP_504_GATEWAY_TIMEOUT
2 changes: 1 addition & 1 deletion pcstac/pcstac/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class Settings(BaseSettings):
api_version: str = f"v{API_VERSION}"
rate_limits: RateLimits = RateLimits()
back_pressures: BackPressures = BackPressures()
request_timout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30)
request_timeout: int = Field(env=REQUEST_TIMEOUT_ENV_VAR, default=30)

def get_tiler_href(self, request: Request) -> str:
"""Generates the tiler HREF.
Expand Down
41 changes: 10 additions & 31 deletions pcstac/pcstac/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""FastAPI application using PGStac."""
import logging
import os
from typing import Any, Awaitable, Callable, Dict
from typing import Any, Dict

from fastapi import FastAPI, HTTPException, Request, Response
from brotli_asgi import BrotliMiddleware
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
from fastapi.openapi.utils import get_openapi
from fastapi.responses import ORJSONResponse
Expand All @@ -16,9 +17,8 @@

from pccommon.logging import ServiceName, init_logging
from pccommon.middleware import (
RequestTracingMiddleware,
handle_exceptions,
timeout_middleware,
add_timeout,
http_exception_handler,
)
from pccommon.openapi import fixup_schema
from pccommon.redis import connect_to_redis
Expand All @@ -32,6 +32,7 @@
get_settings,
)
from pcstac.errors import PC_DEFAULT_STATUS_CODES
from pcstac.middleware import ProxyHeaderHostMiddleware
from pcstac.search import PCSearch, PCSearchGetRequest, RedisBaseItemCache

DEBUG: bool = os.getenv("DEBUG") == "TRUE" or False
Expand Down Expand Up @@ -70,13 +71,16 @@
search_get_request_model=search_get_request_model,
search_post_request_model=search_post_request_model,
response_class=ORJSONResponse,
middlewares=[BrotliMiddleware, ProxyHeaderHostMiddleware],
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
exceptions={**DEFAULT_STATUS_CODES, **PC_DEFAULT_STATUS_CODES},
)

app: FastAPI = api.app

app.state.service_name = ServiceName.STAC

add_timeout(app, app_settings.request_timeout)

# Note: If requests are being sent through an application gateway like
# nginx-ingress, you may need to configure CORS through that system.
app.add_middleware(
Expand All @@ -86,25 +90,6 @@
allow_headers=["*"],
)

app.add_middleware(RequestTracingMiddleware, service_name=ServiceName.STAC)


@app.middleware("http")
async def _timeout_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a timeout to all requests."""
return await timeout_middleware(
request, call_next, timeout=app_settings.request_timout
)


@app.middleware("http")
async def _handle_exceptions(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
return await handle_exceptions(request, call_next)


@app.on_event("startup")
async def startup_event() -> None:
Expand All @@ -119,13 +104,7 @@ async def shutdown_event() -> None:
await close_db_connection(app)


@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request, exc: HTTPException
) -> PlainTextResponse:
return PlainTextResponse(
str(exc.detail), status_code=exc.status_code, headers=exc.headers
)
app.add_exception_handler(Exception, http_exception_handler)


@app.exception_handler(StarletteHTTPException)
Expand Down
Loading