Skip to content

Commit

Permalink
add tracing middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
joshimai committed Apr 9, 2024
1 parent a311b33 commit 78c065e
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 4 deletions.
54 changes: 54 additions & 0 deletions pccommon/pccommon/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import re
from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast

import fastapi
from fastapi import Request, Response
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
from opentelemetry import trace
from starlette.datastructures import QueryParams

from pccommon.config import get_apis_config
from pccommon.constants import (
Expand All @@ -25,6 +28,11 @@
logger = logging.getLogger(__name__)


COLLECTION = "spatio.collection"
COLLECTIONS = "spatio.collections"
ITEM = "spatio.item"
ITEMS = "spatio.items"

exporter = (
AzureExporter(
connection_string=(
Expand Down Expand Up @@ -249,3 +257,49 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
return result
# No collection was found
return None


def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None:
"""
Try to add the Collection ID and Item ID from a search to the current span.
"""
collection_id, item_id = parse_collection_from_search(
json.loads(search_json), request.method, request.query_params
)
span = trace.get_current_span()

if collection_id is not None:
span.set_attribute(COLLECTIONS, collection_id)

if item_id is not None:
span.set_attribute(ITEMS, item_id)


def parse_collection_from_search(
body: dict,
method: str,
query_params: QueryParams,
) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a search request.
The search endpoint is a bit of a special case. If it's a GET, the collection
and item ids are in the querystring. If it's a POST, the collection and item may
be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body.
"""
if method.lower() == "get":
collection_id = query_params.get("collections")
item_id = query_params.get("ids")
return (collection_id, item_id)
elif method.lower() == "post":
try:
if "collections" in body:
return _parse_queryjson(body)
elif "filter" in body:
return _parse_cqljson(body["filter"])
except json.JSONDecodeError as e:
logger.warning(
"Unable to parse search body as JSON. Ignoring collection"
f"parameter. {e}"
)
return (None, None)
2 changes: 2 additions & 0 deletions pccommon/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"html-sanitizer==2.4",
# Soon available as lxml[html_clean]
"lxml_html_clean==0.1.0",
"opentelemetry-api==1.21.0",
"opentelemetry-sdk==1.21.0",
]

extra_reqs = {
Expand Down
3 changes: 3 additions & 0 deletions pcstac/pcstac/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pccommon.constants import DEFAULT_COLLECTION_REGION
from pccommon.logging import get_custom_dimensions
from pccommon.redis import back_pressure, cached_result, rate_limit
from pccommon.tracing import add_stac_attributes_from_search
from pcstac.config import API_DESCRIPTION, API_LANDING_PAGE_ID, API_TITLE, get_settings
from pcstac.contants import (
CACHE_KEY_COLLECTION,
Expand Down Expand Up @@ -227,6 +228,8 @@ async def _fetch() -> ItemCollection:
return item_collection

search_json = search_request.json()
add_stac_attributes_from_search(search_json, request)

logger.info(
"STAC: Item search body",
extra=get_custom_dimensions({"search_body": search_json}, request),
Expand Down
14 changes: 12 additions & 2 deletions pcstac/pcstac/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""FastAPI application using PGStac."""
import logging
import os
from typing import Any, Dict
from typing import Any, Awaitable, Callable, Dict

from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, Response
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
from fastapi.openapi.utils import get_openapi
from fastapi.responses import ORJSONResponse
Expand All @@ -18,6 +18,7 @@
from pccommon.middleware import add_timeout, http_exception_handler
from pccommon.openapi import fixup_schema
from pccommon.redis import connect_to_redis
from pccommon.tracing import trace_request
from pcstac.api import PCStacApi
from pcstac.client import PCClient
from pcstac.config import (
Expand Down Expand Up @@ -75,6 +76,15 @@

add_timeout(app, app_settings.request_timeout)


@app.middleware("http")
async def _request_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a trace to all requests."""
return await trace_request(ServiceName.STAC, request, call_next)


# Note: If requests are being sent through an application gateway like
# nginx-ingress, you may need to configure CORS through that system.
app.add_middleware(
Expand Down
14 changes: 12 additions & 2 deletions pctiler/pctiler/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
import logging
import os
from typing import Dict, List
from typing import Awaitable, Callable, Dict, List

from fastapi import FastAPI
from fastapi import FastAPI, Request, Response
from fastapi.openapi.utils import get_openapi
from morecantile.defaults import tms as defaultTileMatrices
from morecantile.models import TileMatrixSet
Expand All @@ -21,6 +21,7 @@
from pccommon.logging import ServiceName, init_logging
from pccommon.middleware import add_timeout, http_exception_handler
from pccommon.openapi import fixup_schema
from pccommon.tracing import trace_request
from pctiler.config import get_settings
from pctiler.endpoints import (
configuration,
Expand Down Expand Up @@ -89,6 +90,15 @@
add_exception_handlers(app, DEFAULT_STATUS_CODES)
add_exception_handlers(app, MOSAIC_STATUS_CODES)


@app.middleware("http")
async def _request_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a trace to all requests."""
return await trace_request(ServiceName.TILER, request, call_next)


app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600")
app.add_middleware(TotalTimeMiddleware)

Expand Down

0 comments on commit 78c065e

Please sign in to comment.