Skip to content

Commit

Permalink
Starlette trace middleware, remove opentelemetry
Browse files Browse the repository at this point in the history
  • Loading branch information
joshimai committed Apr 10, 2024
1 parent 78c065e commit 2fc1ebd
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 39 deletions.
26 changes: 25 additions & 1 deletion pccommon/pccommon/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import wraps
from typing import Any, Callable

from fastapi import HTTPException, Request
from fastapi import HTTPException, Request, Response
from fastapi.applications import FastAPI
from fastapi.dependencies.utils import (
get_body_field,
Expand All @@ -14,8 +14,10 @@
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute, request_response
from starlette.status import HTTP_504_GATEWAY_TIMEOUT
from starlette.types import ASGIApp, Receive, Scope, Send

from pccommon.logging import get_custom_dimensions
from pccommon.tracing import trace_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -97,3 +99,25 @@ def add_timeout(app: FastAPI, timeout_seconds: float) -> None:
dependant=route.dependant, name=route.unique_id
)
route.app = request_response(route.get_route_handler())


class TraceMiddleware:
def __init__(self, app: ASGIApp, service_name: str):
self.app = app
self.service_name = service_name

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] == "http":
request: Request = Request(scope, receive)

async def call_next(request: Request) -> Response:
# Create a response object to mimic trace_requests call_next
# argument
response = Response()
await self.app(scope, receive, send)
return response

await trace_request(self.service_name, request, call_next)

else:
await self.app(scope, receive, send)
32 changes: 20 additions & 12 deletions pccommon/pccommon/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
from opentelemetry import trace
from starlette.datastructures import QueryParams

from pccommon.config import get_apis_config
Expand All @@ -28,11 +27,6 @@
logger = logging.getLogger(__name__)


COLLECTION = "spatio.collection"
COLLECTIONS = "spatio.collections"
ITEM = "spatio.item"
ITEMS = "spatio.items"

exporter = (
AzureExporter(
connection_string=(
Expand Down Expand Up @@ -266,13 +260,27 @@ def add_stac_attributes_from_search(search_json: str, request: fastapi.Request)
collection_id, item_id = parse_collection_from_search(
json.loads(search_json), request.method, request.query_params
)
span = trace.get_current_span()

if collection_id is not None:
span.set_attribute(COLLECTIONS, collection_id)
tracer = Tracer(
exporter=exporter,
sampler=ProbabilitySampler(1.0),
)

if item_id is not None:
span.set_attribute(ITEMS, item_id)
with tracer.span("main") as span:
if (
hasattr(request.state, "parent_span")
and request.state.parent_span is not None
):
request.state.parent_span = span
if collection_id is not None:
tracer.add_attribute_to_current_span(
attribute_key="collection", attribute_value=collection_id
)
if item_id is not None:
tracer.add_attribute_to_current_span(
attribute_key="item", attribute_value=item_id
)
else:
logger.warning("No 'parent_span' attribute found in request.state")


def parse_collection_from_search(
Expand Down
2 changes: 0 additions & 2 deletions pccommon/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
"html-sanitizer==2.4",
# Soon available as lxml[html_clean]
"lxml_html_clean==0.1.0",
"opentelemetry-api==1.21.0",
"opentelemetry-sdk==1.21.0",
]

extra_reqs = {
Expand Down
16 changes: 4 additions & 12 deletions pcstac/pcstac/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""FastAPI application using PGStac."""
import logging
import os
from typing import Any, Awaitable, Callable, Dict
from typing import Any, Dict

from fastapi import FastAPI, Request, Response
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
from fastapi.openapi.utils import get_openapi
from fastapi.responses import ORJSONResponse
Expand All @@ -15,10 +15,9 @@
from starlette.responses import PlainTextResponse

from pccommon.logging import ServiceName, init_logging
from pccommon.middleware import add_timeout, http_exception_handler
from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler
from pccommon.openapi import fixup_schema
from pccommon.redis import connect_to_redis
from pccommon.tracing import trace_request
from pcstac.api import PCStacApi
from pcstac.client import PCClient
from pcstac.config import (
Expand Down Expand Up @@ -76,14 +75,7 @@

add_timeout(app, app_settings.request_timeout)


@app.middleware("http")
async def _request_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a trace to all requests."""
return await trace_request(ServiceName.STAC, request, call_next)

app.add_middleware(TraceMiddleware, service_name=app.state.service_name)

# Note: If requests are being sent through an application gateway like
# nginx-ingress, you may need to configure CORS through that system.
Expand Down
16 changes: 4 additions & 12 deletions pctiler/pctiler/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
import logging
import os
from typing import Awaitable, Callable, Dict, List
from typing import Dict, List

from fastapi import FastAPI, Request, Response
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi
from morecantile.defaults import tms as defaultTileMatrices
from morecantile.models import TileMatrixSet
Expand All @@ -19,9 +19,8 @@

from pccommon.constants import X_REQUEST_ENTITY
from pccommon.logging import ServiceName, init_logging
from pccommon.middleware import add_timeout, http_exception_handler
from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler
from pccommon.openapi import fixup_schema
from pccommon.tracing import trace_request
from pctiler.config import get_settings
from pctiler.endpoints import (
configuration,
Expand Down Expand Up @@ -91,14 +90,7 @@
add_exception_handlers(app, MOSAIC_STATUS_CODES)


@app.middleware("http")
async def _request_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a trace to all requests."""
return await trace_request(ServiceName.TILER, request, call_next)


app.add_middleware(TraceMiddleware, service_name=app.state.service_name)
app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600")
app.add_middleware(TotalTimeMiddleware)

Expand Down

0 comments on commit 2fc1ebd

Please sign in to comment.