Skip to content

Commit

Permalink
update sharing of context
Browse files Browse the repository at this point in the history
  • Loading branch information
ancalita committed May 14, 2024
1 parent 5a4b705 commit 63245c8
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def configure_cors(
app: Sanic, cors_origins: Union[Text, List[Text], None] = ""
) -> None:
"""Configure CORS origins for the given app."""

CORS(
app, resources={r"/*": {"origins": cors_origins or ""}}, automatic_options=True
)
Expand All @@ -56,7 +55,6 @@ def create_ssl_context(
ssl_password: Optional[Text] = None,
) -> Optional[SSLContext]:
"""Create a SSL context if a certificate is passed."""

if ssl_certificate:
import ssl

Expand All @@ -71,14 +69,18 @@ def create_ssl_context(

def create_argument_parser():
"""Parse all the command line arguments for the run script."""

parser = argparse.ArgumentParser(description="starts the action endpoint")
add_endpoint_arguments(parser)
utils.add_logging_level_option_arguments(parser)
utils.add_logging_file_arguments(parser)
return parser


async def load_tracer_provider(app: Sanic, tracer_provider: Optional[TracerProvider]):
"""Load the tracer provider into the Sanic app."""
app.shared_ctx.tracer_provider = tracer_provider


def create_app(
action_package_name: Union[Text, types.ModuleType],
cors_origins: Union[Text, List[Text], None] = "*",
Expand All @@ -104,9 +106,10 @@ def create_app(
executor = ActionExecutor()
executor.register_package(action_package_name)

@app.main_process_start
async def main_process_start(app: Sanic):
app.shared_ctx.tracer_provider = tracer_provider
app.register_listener(
partial(load_tracer_provider, tracer_provider=tracer_provider),
"main_process_start",
)

@app.get("/health")
async def health(_) -> HTTPResponse:
Expand All @@ -117,7 +120,9 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
tracer, context, span_name = get_tracer_and_context(app.shared_ctx.tracer_provider, request)
tracer, context, span_name = get_tracer_and_context(
app.shared_ctx.tracer_provider, request
)

with tracer.start_as_current_span(span_name, context=context) as span:
if request.headers.get("Content-Encoding") == "deflate":
Expand Down

0 comments on commit 63245c8

Please sign in to comment.