Skip to content

Commit

Permalink
Merge pull request #1016 from RasaHQ/implement-tracing-in-action-server
Browse files Browse the repository at this point in the history
Implement tracing in action server
  • Loading branch information
Tawakalt authored Aug 2, 2023
2 parents 52ec284 + efa5270 commit 85782f3
Show file tree
Hide file tree
Showing 17 changed files with 1,112 additions and 38 deletions.
2 changes: 2 additions & 0 deletions changelog/1016.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added tracing functionality to the Rasa SDK, bringing enhanced monitoring, execution profiling and debugging capabilities to the Rasa Actions Server.
See the [Rasa documentation on tracing](https://rasa.com/docs/rasa/monitoring/tracing/#configuring-a-tracing-backend-or-collector) to learn more about configuring a tracing backend or collector.
462 changes: 453 additions & 9 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ prompt-toolkit = "^3.0,<3.0.29"
"ruamel.yaml" = ">=0.16.5,<0.18.0"
websockets = ">=10.0,<12.0"
pluggy = "^1.0.0"
opentelemetry-api = "~1.15.0"
opentelemetry-sdk = "~1.15.0"
opentelemetry-exporter-jaeger = "~1.15.0"
opentelemetry-exporter-otlp = "~1.15.0"

[tool.poetry.dev-dependencies]
pytest-cov = "^4.1.0"
Expand Down
3 changes: 3 additions & 0 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from rasa_sdk import utils
from rasa_sdk.endpoint import create_argument_parser, run
from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME
from rasa_sdk.tracing.utils import get_tracer_provider


def main_from_args(args):
Expand All @@ -17,6 +18,7 @@ def main_from_args(args):
args.logging_config_file,
)
utils.update_sanic_log_level()
tracer_provider = get_tracer_provider(args)

run(
args.actions,
Expand All @@ -26,6 +28,7 @@ def main_from_args(args):
args.ssl_keyfile,
args.ssl_password,
args.auto_reload,
tracer_provider,
)


Expand Down
70 changes: 41 additions & 29 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import types
import zlib
import json
from opentelemetry.sdk.trace import TracerProvider
from typing import List, Text, Union, Optional
from ssl import SSLContext

Expand All @@ -18,6 +19,7 @@
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException
from rasa_sdk.plugin import plugin_manager
from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,6 +67,7 @@ def create_app(
action_package_name: Union[Text, types.ModuleType],
cors_origins: Union[Text, List[Text], None] = "*",
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
) -> Sanic:
"""Create a Sanic application and return it.
Expand All @@ -73,6 +76,7 @@ def create_app(
from.
cors_origins: CORS origins to allow.
auto_reload: When `True`, auto-reloading of actions is enabled.
tracer_provider: Tracer provider to use for tracing.
Returns:
A new Sanic application ready to be run.
Expand All @@ -93,34 +97,38 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
if request.headers.get("Content-Encoding") == "deflate":
# Decompress the request data using zlib
decompressed_data = zlib.decompress(request.body)
# Load the JSON data from the decompressed request data
action_call = json.loads(decompressed_data)
else:
action_call = request.json
if action_call is None:
body = {"error": "Invalid body request"}
return response.json(body, status=400)

utils.check_version_compatibility(action_call.get("version"))

if auto_reload:
executor.reload()

try:
result = await executor.run(action_call)
except ActionExecutionRejection as e:
logger.debug(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=400)
except ActionNotFoundException as e:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=404)

return response.json(result, status=200)
tracer, context, span_name = get_tracer_and_context(tracer_provider, request)

with tracer.start_as_current_span(span_name, context=context) as span:
if request.headers.get("Content-Encoding") == "deflate":
# Decompress the request data using zlib
decompressed_data = zlib.decompress(request.body)
# Load the JSON data from the decompressed request data
action_call = json.loads(decompressed_data)
else:
action_call = request.json
if action_call is None:
body = {"error": "Invalid body request"}
return response.json(body, status=400)

utils.check_version_compatibility(action_call.get("version"))

if auto_reload:
executor.reload()
try:
result = await executor.run(action_call)
except ActionExecutionRejection as e:
logger.debug(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=400)
except ActionNotFoundException as e:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=404)

set_span_attributes(span, action_call)

return response.json(result, status=200)

@app.get("/actions")
async def actions(_) -> HTTPResponse:
Expand Down Expand Up @@ -151,11 +159,15 @@ def run(
ssl_keyfile: Optional[Text] = None,
ssl_password: Optional[Text] = None,
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
) -> None:
"""Starts the action endpoint server with given config values."""
logger.info("Starting action endpoint server...")
app = create_app(
action_package_name, cors_origins=cors_origins, auto_reload=auto_reload
action_package_name,
cors_origins=cors_origins,
auto_reload=auto_reload,
tracer_provider=tracer_provider,
)
## Attach additional sanic extensions: listeners, middleware and routing
logger.info("Starting plugins...")
Expand Down
166 changes: 166 additions & 0 deletions rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from __future__ import annotations

import abc
import logging
import os
from typing import Any, Dict, Optional, Text

import grpc
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config


TRACING_SERVICE_NAME = os.environ.get("TRACING_SERVICE_NAME", "rasa_sdk")

ENDPOINTS_TRACING_KEY = "tracing"

logger = logging.getLogger(__name__)


def get_tracer_provider(endpoints_file: Text) -> Optional[TracerProvider]:
    """Configure the tracing backend described in an endpoints file.

    The `tracing` section of the endpoints file is read and, when its `type`
    is one of the known backends (`jaeger` or `otlp`), the matching
    `TracerConfigurer` builds the `TracerProvider`. When the section is missing
    or names an unknown backend, nothing is configured and `None` is returned.

    :param endpoints_file: The configuration file containing information about
        the tracing backend.
    :return: The `TracerProvider` to be used for all subsequent tracing, or
        `None` when tracing could not be configured.
    """
    cfg = read_endpoint_config(endpoints_file, ENDPOINTS_TRACING_KEY)

    if not cfg:
        # The two literals are concatenated, so the separating space has to be
        # part of one of them.
        logger.info(
            f"No endpoint for tracing type available in {endpoints_file}, "
            f"tracing will not be configured."
        )
        return None

    if cfg.type == "jaeger":
        return JaegerTracerConfigurer.configure_from_endpoint_config(cfg)
    if cfg.type == "otlp":
        return OTLPCollectorConfigurer.configure_from_endpoint_config(cfg)

    logger.warning(
        f"Unknown tracing type {cfg.type} read from {endpoints_file}, ignoring."
    )
    return None


class TracerConfigurer(abc.ABC):
    """Common interface for backend-specific tracing configuration.

    Every supported tracing backend provides its own subclass and implements
    `configure_from_endpoint_config`.
    """

    @classmethod
    @abc.abstractmethod
    def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
        """Set up tracing for one backend.

        Concrete subclasses must override this method. An implementation is
        expected to read its settings from `cfg`, create whatever tracing
        infrastructure the backend needs, and hand back the resulting
        `TracerProvider`.

        :param cfg: The endpoint configuration describing the tracing backend.
        :return: The configured `TracerProvider`.
        """


class JaegerTracerConfigurer(TracerConfigurer):
    """`TracerConfigurer` implementation that exports spans to Jaeger."""

    @classmethod
    def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
        """Build a `TracerProvider` that exports to a Jaeger backend.

        The service name is taken from the `service_name` entry of
        `cfg.kwargs`, falling back to `TRACING_SERVICE_NAME`. The exporter
        settings come from `_extract_config`, and the exporter is attached to
        the provider through a `BatchSpanProcessor`.

        :param cfg: The configuration to be read for configuring tracing.
        :return: The configured `TracerProvider`.
        """
        service_name = cfg.kwargs.get("service_name", TRACING_SERVICE_NAME)
        tracer_provider = TracerProvider(
            resource=Resource.create({SERVICE_NAME: service_name})
        )

        exporter_options = cls._extract_config(cfg)
        exporter = JaegerExporter(
            udp_split_oversized_batches=True, **exporter_options
        )
        logger.info(
            f"Registered {cfg.type} endpoint for tracing. Traces will be exported to"
            f" {exporter.agent_host_name}:{exporter.agent_port}"
        )

        span_processor = BatchSpanProcessor(exporter)
        tracer_provider.add_span_processor(span_processor)
        return tracer_provider

    @classmethod
    def _extract_config(cls, cfg: EndpointConfig) -> Dict[str, Any]:
        """Translate endpoint settings into `JaegerExporter` keyword arguments.

        `host` defaults to `localhost` and `port` to `6831`; `username` and
        `password` are `None` when they are not configured.
        """
        settings = cfg.kwargs
        return dict(
            agent_host_name=settings.get("host", "localhost"),
            agent_port=settings.get("port", 6831),
            username=settings.get("username"),
            password=settings.get("password"),
        )


class OTLPCollectorConfigurer(TracerConfigurer):
    """The `TracerConfigurer` for an OTLP collector backend."""

    @classmethod
    def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
        """Configure tracing for an OTLP collector.

        This will read the OTLP collector-specific configuration from the
        `EndpointConfig` and create a corresponding `TracerProvider` that
        exports to the given OTLP collector via gRPC.

        The connection is governed by the `insecure` entry of `cfg.kwargs`.
        When it is falsy and `root_certificates` is configured, that file is
        used to build the SSL channel credentials.

        :param cfg: The configuration to be read for configuring tracing.
        :return: The configured `TracerProvider`.
        :raises KeyError: If `cfg.kwargs` has no `endpoint` entry.
        """
        provider = TracerProvider(
            resource=Resource.create(
                {SERVICE_NAME: cfg.kwargs.get("service_name", TRACING_SERVICE_NAME)}
            )
        )

        # `None` when the key is absent; it is passed to the exporter unchanged.
        insecure = cfg.kwargs.get("insecure")

        credentials = cls._get_credentials(cfg, insecure)  # type: ignore

        endpoint = cfg.kwargs["endpoint"]
        otlp_exporter = OTLPSpanExporter(
            endpoint=endpoint,
            insecure=insecure,
            credentials=credentials,
        )
        # The two literals are concatenated, so the separating space has to be
        # part of one of them.
        logger.info(
            f"Registered {cfg.type} endpoint for tracing. "
            f"Traces will be exported to {endpoint}"
        )
        provider.add_span_processor(BatchSpanProcessor(otlp_exporter))

        return provider

    @classmethod
    def _get_credentials(
        cls, cfg: EndpointConfig, insecure: bool
    ) -> Optional[grpc.ChannelCredentials]:
        """Load SSL channel credentials for a secure connection, if configured.

        :param cfg: The endpoint configuration; `root_certificates` is read as
            a path to the certificate file.
        :param insecure: When truthy, no credentials are loaded.
        :return: The channel credentials, or `None` when the connection is
            insecure or no `root_certificates` entry is configured.
        """
        if insecure or "root_certificates" not in cfg.kwargs:
            return None

        with open(cfg.kwargs["root_certificates"], "rb") as f:
            root_cert = f.read()
        return grpc.ssl_channel_credentials(root_certificates=root_cert)
64 changes: 64 additions & 0 deletions rasa_sdk/tracing/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import logging

import os
from typing import Any, Dict, Optional, Text
import rasa_sdk.utils


logger = logging.getLogger(__name__)
DEFAULT_ENCODING = "utf-8"


def read_endpoint_config(
    filename: Text, endpoint_type: Text
) -> Optional["EndpointConfig"]:
    """Read an endpoint configuration file and extract one config.

    :param filename: Path of the YAML endpoints file. A falsy value means
        "no file" and yields `None` without touching the disk.
    :param endpoint_type: Top-level key of the section to extract.
    :return: The `EndpointConfig` built from the requested section, or `None`
        when the file does not exist, is not a YAML mapping, or does not
        contain the section.
    """
    if not filename:
        return None

    try:
        content = rasa_sdk.utils.read_file(filename)
        content = rasa_sdk.utils.read_yaml(content)

        # Parsed YAML is not necessarily a mapping (e.g. a top-level list);
        # calling `.get` on it would raise `AttributeError`.
        if not isinstance(content, dict):
            logger.warning(
                f"Endpoint configuration in {os.path.abspath(filename)} "
                f"is not a YAML mapping, ignoring."
            )
            return None

        section = content.get(endpoint_type)
        if section is None:
            return None

        return EndpointConfig.from_dict(section)
    except FileNotFoundError:
        logger.error(
            "Failed to read endpoint configuration "
            f"from {os.path.abspath(filename)}. No such file."
        )
        return None


class EndpointConfig:
    """Configuration for an external HTTP endpoint."""

    def __init__(
        self,
        url: Optional[Text] = None,
        params: Optional[Dict[Text, Any]] = None,
        headers: Optional[Dict[Text, Any]] = None,
        basic_auth: Optional[Dict[Text, Text]] = None,
        token: Optional[Text] = None,
        token_name: Text = "token",
        cafile: Optional[Text] = None,
        **kwargs: Any,
    ) -> None:
        """Creates an `EndpointConfig` instance.

        `params`, `headers` and `basic_auth` default to empty dictionaries.
        The endpoint type is taken from the `store_type` keyword if present,
        otherwise from `type`; both keys are removed from the extra keyword
        arguments, and whatever remains is stored in `self.kwargs`.
        """
        self.url = url
        self.params = params or {}
        self.headers = headers or {}
        self.basic_auth = basic_auth or {}
        self.token = token
        self.token_name = token_name
        # The inner pop runs first, so `type` is always removed from `kwargs`
        # even when `store_type` takes precedence.
        self.type = kwargs.pop("store_type", kwargs.pop("type", None))
        self.cafile = cafile
        self.kwargs = kwargs

    @classmethod
    def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
        """Create an instance from a dictionary of constructor arguments.

        `cls` is used rather than the hard-coded class name so that calling
        this on a subclass returns an instance of that subclass.
        """
        return cls(**data)
Loading

0 comments on commit 85782f3

Please sign in to comment.