diff --git a/pccommon/pccommon/middleware.py b/pccommon/pccommon/middleware.py index f04a3ac1..f670563e 100644 --- a/pccommon/pccommon/middleware.py +++ b/pccommon/pccommon/middleware.py @@ -4,7 +4,7 @@ from functools import wraps from typing import Any, Callable -from fastapi import HTTPException, Request +from fastapi import HTTPException, Request, Response from fastapi.applications import FastAPI from fastapi.dependencies.utils import ( get_body_field, @@ -14,8 +14,10 @@ from fastapi.responses import PlainTextResponse from fastapi.routing import APIRoute, request_response from starlette.status import HTTP_504_GATEWAY_TIMEOUT +from starlette.types import ASGIApp, Receive, Scope, Send from pccommon.logging import get_custom_dimensions +from pccommon.tracing import trace_request logger = logging.getLogger(__name__) @@ -97,3 +99,25 @@ def add_timeout(app: FastAPI, timeout_seconds: float) -> None: dependant=route.dependant, name=route.unique_id ) route.app = request_response(route.get_route_handler()) + + +class TraceMiddleware: + def __init__(self, app: ASGIApp, service_name: str): + self.app = app + self.service_name = service_name + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + request: Request = Request(scope, receive) + + async def call_next(request: Request) -> Response: + # Create a response object to mimic trace_requests call_next + # argument + response = Response() + await self.app(scope, receive, send) + return response + + await trace_request(self.service_name, request, call_next) + + else: + await self.app(scope, receive, send) diff --git a/pccommon/pccommon/tracing.py b/pccommon/pccommon/tracing.py index 13b5b75f..7da7ec8a 100644 --- a/pccommon/pccommon/tracing.py +++ b/pccommon/pccommon/tracing.py @@ -9,7 +9,6 @@ from opencensus.trace.samplers import ProbabilitySampler from opencensus.trace.span import SpanKind from opencensus.trace.tracer import Tracer -from opentelemetry import trace from starlette.datastructures import QueryParams from pccommon.config import get_apis_config @@ -28,11 +27,6 @@ logger = logging.getLogger(__name__) -COLLECTION = "spatio.collection" -COLLECTIONS = "spatio.collections" -ITEM = "spatio.item" -ITEMS = "spatio.items" - exporter = ( AzureExporter( connection_string=( @@ -266,13 +260,27 @@ def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) collection_id, item_id = parse_collection_from_search( json.loads(search_json), request.method, request.query_params ) - span = trace.get_current_span() - - if collection_id is not None: - span.set_attribute(COLLECTIONS, collection_id) + tracer = Tracer( + exporter=exporter, + sampler=ProbabilitySampler(1.0), + ) - if item_id is not None: - span.set_attribute(ITEMS, item_id) + with tracer.span("main") as span: + if ( + hasattr(request.state, "parent_span") + and request.state.parent_span is not None + ): + request.state.parent_span = span + if collection_id is not None: + tracer.add_attribute_to_current_span( + attribute_key="collection", attribute_value=collection_id + ) + if item_id is not None: + tracer.add_attribute_to_current_span( + attribute_key="item", attribute_value=item_id + ) + else: + logger.warning("No 'parent_span' attribute found in request.state") def parse_collection_from_search( diff --git a/pccommon/setup.py b/pccommon/setup.py index 4197ae50..60cc7e8e 100644 --- a/pccommon/setup.py +++ b/pccommon/setup.py @@ -20,8 +20,6 @@ "html-sanitizer==2.4", # Soon available as lxml[html_clean] "lxml_html_clean==0.1.0", - "opentelemetry-api==1.21.0", - "opentelemetry-sdk==1.21.0", ] extra_reqs = { diff --git a/pcstac/pcstac/main.py b/pcstac/pcstac/main.py index 405ae28c..55ba7950 100644 --- a/pcstac/pcstac/main.py +++ b/pcstac/pcstac/main.py @@ -1,9 +1,9 @@ """FastAPI application using PGStac.""" import logging import os -from typing import Any, Awaitable, Callable, Dict +from typing import Any, Dict -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI, Request from fastapi.exceptions import RequestValidationError, StarletteHTTPException from fastapi.openapi.utils import get_openapi from fastapi.responses import ORJSONResponse @@ -15,10 +15,9 @@ from starlette.responses import PlainTextResponse from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import add_timeout, http_exception_handler +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema from pccommon.redis import connect_to_redis -from pccommon.tracing import trace_request from pcstac.api import PCStacApi from pcstac.client import PCClient from pcstac.config import ( @@ -76,14 +75,7 @@ add_timeout(app, app_settings.request_timeout) - -@app.middleware("http") -async def _request_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a trace to all requests.""" - return await trace_request(ServiceName.STAC, request, call_next) - +app.add_middleware(TraceMiddleware, service_name=app.state.service_name) # Note: If requests are being sent through an application gateway like # nginx-ingress, you may need to configure CORS through that system. diff --git a/pctiler/pctiler/main.py b/pctiler/pctiler/main.py index 3553d0f4..cb32b9dc 100755 --- a/pctiler/pctiler/main.py +++ b/pctiler/pctiler/main.py @@ -1,9 +1,9 @@ #!/usr/bin/env python3 import logging import os -from typing import Awaitable, Callable, Dict, List +from typing import Dict, List -from fastapi import FastAPI, Request, Response +from fastapi import FastAPI from fastapi.openapi.utils import get_openapi from morecantile.defaults import tms as defaultTileMatrices from morecantile.models import TileMatrixSet @@ -19,9 +19,8 @@ from pccommon.constants import X_REQUEST_ENTITY from pccommon.logging import ServiceName, init_logging -from pccommon.middleware import add_timeout, http_exception_handler +from pccommon.middleware import TraceMiddleware, add_timeout, http_exception_handler from pccommon.openapi import fixup_schema -from pccommon.tracing import trace_request from pctiler.config import get_settings from pctiler.endpoints import ( configuration, @@ -91,14 +90,7 @@ add_exception_handlers(app, MOSAIC_STATUS_CODES) -@app.middleware("http") -async def _request_middleware( - request: Request, call_next: Callable[[Request], Awaitable[Response]] -) -> Response: - """Add a trace to all requests.""" - return await trace_request(ServiceName.TILER, request, call_next) - - +app.add_middleware(TraceMiddleware, service_name=app.state.service_name) app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600") app.add_middleware(TotalTimeMiddleware)