Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove BaseHTTPMiddlewares , Ensure origin host is used in STAC links #190

Merged
merged 15 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 3 additions & 28 deletions pccommon/pccommon/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import wraps
from typing import Any, Callable

from fastapi import HTTPException, Request, Response
from fastapi import Request
from fastapi.applications import FastAPI
from fastapi.dependencies.utils import (
get_body_field,
Expand All @@ -16,27 +16,11 @@
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
from starlette.types import ASGIApp, Receive, Scope, Send

from pccommon.logging import get_custom_dimensions
from pccommon.tracing import trace_request

logger = logging.getLogger(__name__)


async def http_exception_handler(request: Request, exc: Exception) -> Any:
# Log the exception with additional request info if needed
logger.exception("Exception when handling request", exc_info=exc)
# Return a custom response for HTTPException
if isinstance(exc, HTTPException):
raise
# Handle other exceptions, possibly with a generic response
else:
logger.exception(
"Exception when handling request",
extra=get_custom_dimensions({"stackTrace": f"{exc}"}, request),
)
raise


def with_timeout(
timeout_seconds: float,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
Expand Down Expand Up @@ -109,15 +93,6 @@ def __init__(self, app: ASGIApp, service_name: str):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
request: Request = Request(scope, receive)
await trace_request(self.service_name, request)

async def call_next(request: Request) -> Response:
# Create a response object to mimic trace_requests call_next
# argument
response = Response()
await self.app(scope, receive, send)
return response

await trace_request(self.service_name, request, call_next)

else:
await self.app(scope, receive, send)
await self.app(scope, receive, send)
82 changes: 13 additions & 69 deletions pccommon/pccommon/tracing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
import logging
import re
from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast
from typing import List, Optional, Tuple, Union, cast

import fastapi
from fastapi import Request, Response
from fastapi import Request
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace import execution_context
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
Expand All @@ -15,7 +16,6 @@
from pccommon.constants import (
HTTP_METHOD,
HTTP_PATH,
HTTP_STATUS_CODE,
HTTP_URL,
X_AZURE_REF,
X_REQUEST_ENTITY,
Expand Down Expand Up @@ -43,8 +43,7 @@
async def trace_request(
service_name: str,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
) -> None:
"""Construct a request trace with custom dimensions"""
request_path = request_to_path(request).strip("/")

Expand Down Expand Up @@ -101,17 +100,6 @@ async def trace_request(
attribute_key="item", attribute_value=item_id
)

# Call next middleware
response = await call_next(request)

# Include response dimensions in the trace
tracer.add_attribute_to_current_span(
attribute_key=HTTP_STATUS_CODE, attribute_value=response.status_code
)
return response
else:
return await call_next(request)


collection_id_re = re.compile(
r".*/collections/?(?P<collection_id>[a-zA-Z0-9\-\%]+)?(/items/(?P<item_id>.*))?.*" # noqa
Expand All @@ -124,7 +112,6 @@ async def _collection_item_from_request(
) -> Tuple[Optional[str], Optional[str]]:
"""Attempt to get collection and item ids from the request path or querystring."""
url = request.url
path = url.path.strip("/")
try:
collection_id_match = collection_id_re.match(f"{url}")
if collection_id_match:
Expand All @@ -133,8 +120,6 @@ async def _collection_item_from_request(
)
item_id = cast(Optional[str], collection_id_match.group("item_id"))
return (collection_id, item_id)
elif path.endswith("/search") or path.endswith("/register"):
return await _parse_collection_from_search(request)
else:
collection_id = request.query_params.get("collection")
# Some endpoints, like preview/, take an `items` parameter, but
Expand All @@ -161,35 +146,6 @@ def _should_trace_request(request: Request) -> bool:
)


async def _parse_collection_from_search(
request: Request,
) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a search request.

The search endpoint is a bit of a special case. If it's a GET, the collection
and item ids are in the querystring. If it's a POST, the collection and item may
be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body.
"""

if request.method.lower() == "get":
collection_id = request.query_params.get("collections")
item_id = request.query_params.get("ids")
return (collection_id, item_id)
elif request.method.lower() == "post":
try:
body = await request.json()
if "collections" in body:
return _parse_queryjson(body)
elif "filter" in body:
return _parse_cqljson(body["filter"])
except json.JSONDecodeError:
logger.warning(
"Unable to parse search body as JSON. Ignoring collection parameter."
)
return (None, None)


def _parse_cqljson(cql: dict) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a CQL-JSON filter.
Expand Down Expand Up @@ -249,7 +205,6 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
result = _iter_cql(item, property_name)
if result is not None:
return result
# No collection was found
return None


Expand All @@ -260,27 +215,16 @@ def add_stac_attributes_from_search(search_json: str, request: fastapi.Request)
collection_id, item_id = parse_collection_from_search(
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
json.loads(search_json), request.method, request.query_params
)
tracer = Tracer(
exporter=exporter,
sampler=ProbabilitySampler(1.0),
)
parent_span = getattr(request.state, "parent_span", None)

with tracer.span("main") as span:
if (
hasattr(request.state, "parent_span")
and request.state.parent_span is not None
):
request.state.parent_span = span
if collection_id is not None:
tracer.add_attribute_to_current_span(
attribute_key="collection", attribute_value=collection_id
)
if item_id is not None:
tracer.add_attribute_to_current_span(
attribute_key="item", attribute_value=item_id
)
else:
logger.warning("No 'parent_span' attribute found in request.state")
current_span = execution_context.get_current_span() or parent_span

if current_span:
current_span.add_attribute("collection", collection_id)
if item_id is not None:
current_span.add_attribute("item", item_id)
else:
logger.warning("No active or parent span available for adding attributes.")


def parse_collection_from_search(
Expand Down
12 changes: 9 additions & 3 deletions pcstac/pcstac/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
from typing import Any, Dict

from fastapi import FastAPI, Request
from fastapi import FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
from fastapi.openapi.utils import get_openapi
from fastapi.responses import ORJSONResponse
Expand All @@ -15,7 +15,7 @@
from starlette.responses import PlainTextResponse

from pccommon.logging import ServiceName, init_logging
from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler
from pccommon.middleware import TraceMiddleware, add_timeout
from pccommon.openapi import fixup_schema
from pccommon.redis import connect_to_redis
from pcstac.api import PCStacApi
Expand Down Expand Up @@ -100,7 +100,13 @@ async def shutdown_event() -> None:
await close_db_connection(app)


app.add_exception_handler(Exception, http_exception_handler)
@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request, exc: HTTPException
) -> PlainTextResponse:
return PlainTextResponse(
str(exc.detail), status_code=exc.status_code, headers=exc.headers
)


@app.exception_handler(StarletteHTTPException)
Expand Down
15 changes: 12 additions & 3 deletions pctiler/pctiler/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import os
from typing import Dict, List

from fastapi import FastAPI
from fastapi import FastAPI, HTTPException, Request
from fastapi.openapi.utils import get_openapi
from fastapi.responses import PlainTextResponse
from morecantile.defaults import tms as defaultTileMatrices
from morecantile.models import TileMatrixSet
from starlette.middleware.cors import CORSMiddleware
Expand All @@ -19,7 +20,7 @@

from pccommon.constants import X_REQUEST_ENTITY
from pccommon.logging import ServiceName, init_logging
from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler
from pccommon.middleware import TraceMiddleware, add_timeout
from pccommon.openapi import fixup_schema
from pctiler.config import get_settings
from pctiler.endpoints import (
Expand Down Expand Up @@ -84,12 +85,20 @@

app.include_router(health.health_router, tags=["Liveliness/Readiness"])

app.add_exception_handler(Exception, http_exception_handler)
add_timeout(app, settings.request_timeout)
add_exception_handlers(app, DEFAULT_STATUS_CODES)
add_exception_handlers(app, MOSAIC_STATUS_CODES)


@app.exception_handler(HTTPException)
async def http_exception_handler(
request: Request, exc: HTTPException
) -> PlainTextResponse:
return PlainTextResponse(
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
str(exc.detail), status_code=exc.status_code, headers=exc.headers
)


app.add_middleware(TraceMiddleware, service_name=app.state.service_name)
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600")
app.add_middleware(TotalTimeMiddleware)
Expand Down
Loading