Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove BaseHTTPMiddlewares , Ensure origin host is used in STAC links #190

Merged
merged 15 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions pccommon/pccommon/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import re
from typing import Awaitable, Callable, List, Optional, Tuple, Union, cast

import fastapi
from fastapi import Request, Response
from opencensus.ext.azure.trace_exporter import AzureExporter
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.span import SpanKind
from opencensus.trace.tracer import Tracer
from opentelemetry import trace
from starlette.datastructures import QueryParams

from pccommon.config import get_apis_config
from pccommon.constants import (
Expand All @@ -25,6 +28,11 @@
logger = logging.getLogger(__name__)


COLLECTION = "spatio.collection"
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
COLLECTIONS = "spatio.collections"
ITEM = "spatio.item"
ITEMS = "spatio.items"

exporter = (
AzureExporter(
connection_string=(
Expand Down Expand Up @@ -249,3 +257,49 @@ def _iter_cql(cql: dict, property_name: str) -> Optional[Union[str, List[str]]]:
return result
# No collection was found
return None


def add_stac_attributes_from_search(search_json: str, request: fastapi.Request) -> None:
"""
Try to add the Collection ID and Item ID from a search to the current span.
"""
collection_id, item_id = parse_collection_from_search(
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
json.loads(search_json), request.method, request.query_params
)
span = trace.get_current_span()
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved

if collection_id is not None:
span.set_attribute(COLLECTIONS, collection_id)

if item_id is not None:
span.set_attribute(ITEMS, item_id)


def parse_collection_from_search(
body: dict,
method: str,
query_params: QueryParams,
) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the collection id from a search request.

The search endpoint is a bit of a special case. If it's a GET, the collection
and item ids are in the querystring. If it's a POST, the collection and item may
be in either a CQL-JSON or CQL2-JSON filter body, or a query/stac-ql body.
"""
if method.lower() == "get":
collection_id = query_params.get("collections")
item_id = query_params.get("ids")
return (collection_id, item_id)
elif method.lower() == "post":
try:
if "collections" in body:
return _parse_queryjson(body)
elif "filter" in body:
return _parse_cqljson(body["filter"])
except json.JSONDecodeError as e:
logger.warning(
"Unable to parse search body as JSON. Ignoring collection"
f"parameter. {e}"
)
return (None, None)
2 changes: 2 additions & 0 deletions pccommon/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
"html-sanitizer==2.4",
# Soon available as lxml[html_clean]
"lxml_html_clean==0.1.0",
"opentelemetry-api==1.21.0",
"opentelemetry-sdk==1.21.0",
]

extra_reqs = {
Expand Down
3 changes: 3 additions & 0 deletions pcstac/pcstac/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pccommon.constants import DEFAULT_COLLECTION_REGION
from pccommon.logging import get_custom_dimensions
from pccommon.redis import back_pressure, cached_result, rate_limit
from pccommon.tracing import add_stac_attributes_from_search
from pcstac.config import API_DESCRIPTION, API_LANDING_PAGE_ID, API_TITLE, get_settings
from pcstac.contants import (
CACHE_KEY_COLLECTION,
Expand Down Expand Up @@ -227,6 +228,8 @@ async def _fetch() -> ItemCollection:
return item_collection

search_json = search_request.json()
add_stac_attributes_from_search(search_json, request)

logger.info(
"STAC: Item search body",
extra=get_custom_dimensions({"search_body": search_json}, request),
Expand Down
14 changes: 12 additions & 2 deletions pcstac/pcstac/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""FastAPI application using PGStac."""
import logging
import os
from typing import Any, Dict
from typing import Any, Awaitable, Callable, Dict

from fastapi import FastAPI, Request
from fastapi import FastAPI, Request, Response
from fastapi.exceptions import RequestValidationError, StarletteHTTPException
from fastapi.openapi.utils import get_openapi
from fastapi.responses import ORJSONResponse
Expand All @@ -18,6 +18,7 @@
from pccommon.middleware import add_timeout, http_exception_handler
from pccommon.openapi import fixup_schema
from pccommon.redis import connect_to_redis
from pccommon.tracing import trace_request
from pcstac.api import PCStacApi
from pcstac.client import PCClient
from pcstac.config import (
Expand Down Expand Up @@ -75,6 +76,15 @@

add_timeout(app, app_settings.request_timeout)


@app.middleware("http")
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved
async def _request_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a trace to all requests."""
return await trace_request(ServiceName.STAC, request, call_next)
mmcfarland marked this conversation as resolved.
Show resolved Hide resolved


# Note: If requests are being sent through an application gateway like
# nginx-ingress, you may need to configure CORS through that system.
app.add_middleware(
Expand Down
14 changes: 12 additions & 2 deletions pctiler/pctiler/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
import logging
import os
from typing import Dict, List
from typing import Awaitable, Callable, Dict, List

from fastapi import FastAPI
from fastapi import FastAPI, Request, Response
from fastapi.openapi.utils import get_openapi
from morecantile.defaults import tms as defaultTileMatrices
from morecantile.models import TileMatrixSet
Expand All @@ -21,6 +21,7 @@
from pccommon.logging import ServiceName, init_logging
from pccommon.middleware import add_timeout, http_exception_handler
from pccommon.openapi import fixup_schema
from pccommon.tracing import trace_request
from pctiler.config import get_settings
from pctiler.endpoints import (
configuration,
Expand Down Expand Up @@ -89,6 +90,15 @@
add_exception_handlers(app, DEFAULT_STATUS_CODES)
add_exception_handlers(app, MOSAIC_STATUS_CODES)


@app.middleware("http")
async def _request_middleware(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Add a trace to all requests."""
return await trace_request(ServiceName.TILER, request, call_next)


app.add_middleware(CacheControlMiddleware, cachecontrol="public, max-age=3600")
app.add_middleware(TotalTimeMiddleware)

Expand Down
Loading