Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to Sanic 22.12LTS #1103

Merged
merged 12 commits into from
Jun 4, 2024
2 changes: 2 additions & 0 deletions changelog/1103.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Upgrade Sanic to v22.12LTS.
Refactor loading of tracer provider to be triggered by Sanic `before_server_start` event listener.
35 changes: 18 additions & 17 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ select = [ "D", "E", "F", "W", "RUF",]
[tool.poetry.dependencies]
python = ">=3.8,<3.11"
coloredlogs = ">=10,<16"
sanic = "^21.12.0"
sanic = "^22.12"
typing-extensions = ">=4.1.1,<5.0.0"
Sanic-Cors = "^2.0.0"
prompt-toolkit = "^3.0,<3.0.29"
Expand All @@ -99,7 +99,7 @@ toml = "^0.10.0"
pep440-version-utils = "^0.3.0"
semantic_version = "^2.8.5"
mypy = "^1.5"
sanic-testing = "^22.3.0, <22.9.0"
sanic-testing = "^22.12"

[tool.ruff.pydocstyle]
convention = "google"
Expand Down
4 changes: 1 addition & 3 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from rasa_sdk import utils
from rasa_sdk.endpoint import create_argument_parser, run
from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME
from rasa_sdk.tracing.utils import get_tracer_provider


def main_from_args(args):
Expand All @@ -18,7 +17,6 @@ def main_from_args(args):
args.logging_config_file,
)
utils.update_sanic_log_level()
tracer_provider = get_tracer_provider(args)

run(
args.actions,
Expand All @@ -28,7 +26,7 @@ def main_from_args(args):
args.ssl_keyfile,
args.ssl_password,
args.auto_reload,
tracer_provider,
args.endpoints,
)


Expand Down
65 changes: 48 additions & 17 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import warnings
import zlib
import json
from functools import partial
from typing import List, Text, Union, Optional
from ssl import SSLContext
from sanic import Sanic, response
from sanic.response import HTTPResponse
from sanic.worker.loader import AppLoader

# catching:
# - all `pkg_resources` deprecation warning from multiple dependencies
Expand All @@ -24,16 +26,23 @@
category=DeprecationWarning,
message="distutils Version classes are deprecated",
)
from opentelemetry.sdk.trace import TracerProvider
from sanic_cors import CORS
from sanic.request import Request
from rasa_sdk import utils
from rasa_sdk.cli.arguments import add_endpoint_arguments
from rasa_sdk.constants import DEFAULT_KEEP_ALIVE_TIMEOUT, DEFAULT_SERVER_PORT
from rasa_sdk.constants import (
DEFAULT_ENDPOINTS_PATH,
DEFAULT_KEEP_ALIVE_TIMEOUT,
DEFAULT_SERVER_PORT,
)
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException
from rasa_sdk.plugin import plugin_manager
from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes
from rasa_sdk.tracing.utils import (
get_tracer_and_context,
get_tracer_provider,
set_span_attributes,
)

logger = logging.getLogger(__name__)

Expand All @@ -42,7 +51,6 @@ def configure_cors(
app: Sanic, cors_origins: Union[Text, List[Text], None] = ""
) -> None:
"""Configure CORS origins for the given app."""

CORS(
app, resources={r"/*": {"origins": cors_origins or ""}}, automatic_options=True
)
Expand All @@ -54,7 +62,6 @@ def create_ssl_context(
ssl_password: Optional[Text] = None,
) -> Optional[SSLContext]:
"""Create a SSL context if a certificate is passed."""

if ssl_certificate:
import ssl

Expand All @@ -69,19 +76,23 @@ def create_ssl_context(

def create_argument_parser():
"""Parse all the command line arguments for the run script."""

parser = argparse.ArgumentParser(description="starts the action endpoint")
add_endpoint_arguments(parser)
utils.add_logging_level_option_arguments(parser)
utils.add_logging_file_arguments(parser)
return parser


async def load_tracer_provider(endpoints: str, app: Sanic):
"""Load the tracer provider into the Sanic app."""
tracer_provider = get_tracer_provider(endpoints)
app.ctx.tracer_provider = tracer_provider


def create_app(
action_package_name: Union[Text, types.ModuleType],
cors_origins: Union[Text, List[Text], None] = "*",
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
) -> Sanic:
"""Create a Sanic application and return it.

Expand All @@ -90,7 +101,6 @@ def create_app(
from.
cors_origins: CORS origins to allow.
auto_reload: When `True`, auto-reloading of actions is enabled.
tracer_provider: Tracer provider to use for tracing.

Returns:
A new Sanic application ready to be run.
Expand All @@ -102,6 +112,8 @@ def create_app(
executor = ActionExecutor()
executor.register_package(action_package_name)

app.ctx.tracer_provider = None

@app.get("/health")
async def health(_) -> HTTPResponse:
"""Ping endpoint to check if the server is running and well."""
Expand All @@ -111,7 +123,9 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
tracer, context, span_name = get_tracer_and_context(tracer_provider, request)
tracer, context, span_name = get_tracer_and_context(
request.app.ctx.tracer_provider, request
)

with tracer.start_as_current_span(span_name, context=context) as span:
if request.headers.get("Content-Encoding") == "deflate":
Expand Down Expand Up @@ -173,27 +187,44 @@ def run(
ssl_keyfile: Optional[Text] = None,
ssl_password: Optional[Text] = None,
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
endpoints: str = DEFAULT_ENDPOINTS_PATH,
keep_alive_timeout: int = DEFAULT_KEEP_ALIVE_TIMEOUT,
) -> None:
"""Starts the action endpoint server with given config values."""
logger.info("Starting action endpoint server...")
app = create_app(
action_package_name,
cors_origins=cors_origins,
auto_reload=auto_reload,
tracer_provider=tracer_provider,
loader = AppLoader(
factory=partial(
create_app,
action_package_name,
cors_origins=cors_origins,
auto_reload=auto_reload,
),
)
app = loader.load()

app.config.KEEP_ALIVE_TIMEOUT = keep_alive_timeout
## Attach additional sanic extensions: listeners, middleware and routing

app.register_listener(
partial(load_tracer_provider, endpoints),
"before_server_start",
)

# Attach additional sanic extensions: listeners, middleware and routing
logger.info("Starting plugins...")
plugin_manager().hook.attach_sanic_app_extensions(app=app)

ssl_context = create_ssl_context(ssl_certificate, ssl_keyfile, ssl_password)
protocol = "https" if ssl_context else "http"
host = os.environ.get("SANIC_HOST", "0.0.0.0")

logger.info(f"Action endpoint is up and running on {protocol}://{host}:{port}")
app.run(host, port, ssl=ssl_context, workers=utils.number_of_sanic_workers())
app.run(
host=host,
port=port,
ssl=ssl_context,
workers=utils.number_of_sanic_workers(),
legacy=True,
)


if __name__ == "__main__":
Expand Down
20 changes: 6 additions & 14 deletions rasa_sdk/tracing/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
from rasa_sdk.tracing import config
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
Expand All @@ -9,25 +8,18 @@
from typing import Optional, Tuple, Any, Text


def get_tracer_provider(
cmdline_arguments: argparse.Namespace,
) -> Optional[TracerProvider]:
def get_tracer_provider(endpoints_file: str) -> Optional[TracerProvider]:
"""Gets the tracer provider from the command line arguments."""
tracer_provider = None
endpoints_file = ""
if "endpoints" in cmdline_arguments:
endpoints_file = cmdline_arguments.endpoints

if endpoints_file is not None:
tracer_provider = config.get_tracer_provider(endpoints_file)
config.configure_tracing(tracer_provider)
tracer_provider = config.get_tracer_provider(endpoints_file)
config.configure_tracing(tracer_provider)

return tracer_provider


def get_tracer_and_context(
tracer_provider: Optional[TracerProvider], request: Request
) -> Tuple[Any, Any, Text]:
"""Gets tracer and context"""
"""Gets tracer and context."""
span_name = "create_app.webhook"
if tracer_provider is None:
tracer = trace.get_tracer(span_name)
Expand All @@ -39,7 +31,7 @@ def get_tracer_and_context(


def set_span_attributes(span: Any, action_call: dict) -> None:
"""Sets span attributes"""
"""Sets span attributes."""
tracker = action_call.get("tracker", {})
set_span_attributes = {
"http.method": "POST",
Expand Down
Loading
Loading