Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ATO-2609] Add gRPC tracing #1117

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1117.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Tracing is supported for actions called over gRPC protocol.
42 changes: 38 additions & 4 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import zlib
import json
from functools import partial
from typing import List, Text, Union, Optional
from typing import List, Text, Union, Optional, Any
from ssl import SSLContext

from multidict import MultiDict
from sanic import Sanic, response
from sanic.compat import Header
from sanic.response import HTTPResponse
from sanic.worker.loader import AppLoader

Expand Down Expand Up @@ -127,8 +130,21 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
tracer, context, span_name = get_tracer_and_context(
request.app.ctx.tracer_provider, request
span_name = "create_app.webhook"

def header_to_multi_dict(headers: Header) -> MultiDict:
return MultiDict(
[
(key, value)
for key, value in headers.items()
if key.lower() not in ("content-length", "content-encoding")
]
)

tracer, context = get_tracer_and_context(
span_name=span_name,
tracer_provider=request.app.ctx.tracer_provider,
tracing_carrier=header_to_multi_dict(request.headers),
)

with tracer.start_as_current_span(span_name, context=context) as span:
Expand Down Expand Up @@ -162,7 +178,12 @@ async def webhook(request: Request) -> HTTPResponse:
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=449)

set_span_attributes(span, action_call)
set_http_span_attributes(
span,
action_call,
http_method="POST",
route="/webhook",
)

return response.json(result, status=200)

Expand Down Expand Up @@ -238,6 +259,19 @@ def run(
)


def set_http_span_attributes(
span: Any,
action_call: dict,
http_method: str,
route: str,
) -> None:
"""Sets http span attributes."""
set_span_attributes(span, action_call)
if span.is_recording():
span.set_attribute("http.method", http_method)
span.set_attribute("http.route", route)


if __name__ == "__main__":
import rasa_sdk.__main__

Expand Down
40 changes: 31 additions & 9 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import grpc
import logging
import types
from typing import Union, Optional
from typing import Union, Optional, Any, Dict
from concurrent import futures
from grpc import aio
from google.protobuf.json_format import MessageToDict, ParseDict
from opentelemetry import trace
from grpc.aio import Metadata
from multidict import MultiDict

from rasa_sdk.constants import (
DEFAULT_SERVER_PORT,
Expand Down Expand Up @@ -43,6 +44,8 @@
from rasa_sdk.tracing.utils import (
get_tracer_provider,
TracerProvider,
get_tracer_and_context,
set_span_attributes,
)
from rasa_sdk.utils import (
check_version_compatibility,
Expand Down Expand Up @@ -135,13 +138,23 @@ async def Webhook(
gRPC response.
"""
span_name = "GRPCActionServerWebhook.Webhook"
tracer = (
self.tracer_provider.get_tracer(span_name)
if self.tracer_provider
else trace.get_tracer(span_name)
invocation_metadata = context.invocation_metadata()

def convert_metadata_to_multidict(
metadata: Optional[Metadata],
) -> Optional[MultiDict]:
"""Convert list of tuples to multidict."""
if not metadata:
return None
return MultiDict(metadata)

tracer, tracing_context = get_tracer_and_context(
span_name=span_name,
tracer_provider=self.tracer_provider,
tracing_carrier=convert_metadata_to_multidict(invocation_metadata),
)

with tracer.start_as_current_span(span_name):
with tracer.start_as_current_span(span_name, context=tracing_context) as span:
check_version_compatibility(request.version)
if self.auto_reload:
self.executor.reload()
Expand Down Expand Up @@ -179,12 +192,21 @@ async def Webhook(
return action_webhook_pb2.WebhookResponse()
if not result:
return action_webhook_pb2.WebhookResponse()
# set_span_attributes(span, request)
response = action_webhook_pb2.WebhookResponse()

set_grpc_span_attributes(span, action_call, method_name="Webhook")
response = action_webhook_pb2.WebhookResponse()
return ParseDict(result, response)


def set_grpc_span_attributes(
span: Any, action_call: Dict[str, Any], method_name: str
) -> None:
"""Sets grpc span attributes."""
set_span_attributes(span, action_call)
if span.is_recording():
span.set_attribute("grpc.method", method_name)


def get_signal_name(signal_number: int) -> str:
"""Return the signal name for the given signal number."""
return signal.Signals(signal_number).name
Expand Down
27 changes: 15 additions & 12 deletions rasa_sdk/tracing/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from multidict import MultiDict

from rasa_sdk.tracing import config
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from opentelemetry.sdk.trace import TracerProvider
from sanic.request import Request

from typing import Optional, Tuple, Any, Text, Union
from typing import Optional, Tuple, Any


def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]:
Expand All @@ -17,34 +18,36 @@ def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]:


def get_tracer_and_context(
tracer_provider: Optional[TracerProvider], request: Union[Request]
) -> Tuple[Any, Any, Text]:
span_name: str,
tracer_provider: Optional[TracerProvider],
tracing_carrier: Optional[MultiDict],
) -> Tuple[Any, Any]:
"""Gets tracer and context."""
span_name = "create_app.webhook"

if tracer_provider is None:
tracer = trace.get_tracer(span_name)
context = None
else:
tracer = tracer_provider.get_tracer(span_name)
context = TraceContextTextMapPropagator().extract(request.headers)
return (tracer, context, span_name)
context = (
TraceContextTextMapPropagator().extract(tracing_carrier)
if tracing_carrier
else None
)
return tracer, context


def set_span_attributes(span: Any, action_call: dict) -> None:
"""Sets span attributes."""
tracker = action_call.get("tracker", {})
set_span_attributes = {
"http.method": "POST",
"http.route": "/webhook",
span_attributes = {
"next_action": action_call.get("next_action"),
"version": action_call.get("version"),
"sender_id": tracker.get("sender_id"),
"message_id": tracker.get("latest_message", {}).get("message_id"),
}

if span.is_recording():
for key, value in set_span_attributes.items():
for key, value in span_attributes.items():
span.set_attribute(key, value)

return None
12 changes: 8 additions & 4 deletions tests/tracing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def test_get_tracer_provider_returns_none_if_tracing_is_not_configured() -> None

def test_get_tracer_provider_returns_provider() -> None:
"""Tests that get_tracer_provider returns a TracerProvider
if tracing is configured."""
if tracing is configured.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--endpoints", type=str, default=None)

Expand All @@ -58,7 +59,7 @@ def test_get_tracer_provider_returns_provider() -> None:


def test_get_tracer_and_context() -> None:
"""Tests that get_tracer_and_context returns a ProxyTracer and span name"""
"""Tests that get_tracer_and_context returns a ProxyTracer and span name."""
data = {
"next_action": "custom_action",
"version": "1.0.0",
Expand All @@ -70,8 +71,11 @@ def test_get_tracer_and_context() -> None:
}
app = ep.create_app(None)
request, _ = app.test_client.post("/webhook", data=json.dumps(data))
tracer, context, span_name = get_tracer_and_context(None, request)
tracer, context = get_tracer_and_context(
span_name="create_app.webhook",
tracer_provider=None,
tracing_carrier=request.headers,
)

assert isinstance(tracer, ProxyTracer)
assert span_name == "create_app.webhook"
assert context is None
Loading