diff --git a/rasa_sdk/grpc_server.py b/rasa_sdk/grpc_server.py index 65559403..4660932c 100644 --- a/rasa_sdk/grpc_server.py +++ b/rasa_sdk/grpc_server.py @@ -3,6 +3,7 @@ import signal import asyncio +import time import grpc import logging @@ -57,6 +58,9 @@ logger = logging.getLogger(__name__) +GRPC_ACTION_SERVER_NAME = "ActionServer" + + class GRPCActionServerWebhook(action_webhook_pb2_grpc.ActionServiceServicer): """Runs webhook RPC which is served through gRPC server.""" @@ -173,12 +177,12 @@ def convert_metadata_to_multidict( if not result: return action_webhook_pb2.WebhookResponse() - set_grpc_span_attributes(span, action_call, method_name="Webhook") + _set_grpc_span_attributes(span, action_call, method_name="Webhook") response = action_webhook_pb2.WebhookResponse() return ParseDict(result, response) -def set_grpc_span_attributes( +def _set_grpc_span_attributes( span: Any, action_call: Dict[str, Any], method_name: str ) -> None: """Sets grpc span attributes.""" @@ -187,18 +191,18 @@ def set_grpc_span_attributes( span.set_attribute("grpc.method", method_name) -def get_signal_name(signal_number: int) -> str: +def _get_signal_name(signal_number: int) -> str: """Return the signal name for the given signal number.""" return signal.Signals(signal_number).name -def initialise_interrupts(server: grpc.aio.Server) -> None: +def _initialise_interrupts(server: grpc.aio.Server) -> None: """Initialise handlers for kernel signal interrupts.""" async def handle_sigint(signal_received: int): """Handle the received signal.""" logger.info( - f"Received {get_signal_name(signal_received)} signal." + f"Received {_get_signal_name(signal_received)} signal." "Stopping gRPC server..." ) await server.stop(NO_GRACE_PERIOD) @@ -213,13 +217,38 @@ async def handle_sigint(signal_received: int): ) -def _configure_health_server(server: grpc.Server): +def _initialise_health_service(server: grpc.Server): + """Initialise the health service. + + Args: + server: The gRPC server. + """ health_servicer = health.HealthServicer( experimental_non_blocking=True, experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=10), ) + health_servicer.set(GRPC_ACTION_SERVER_NAME, health_pb2.HealthCheckResponse.SERVING) health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server) - health_servicer.set("ActionServer", health_pb2.HealthCheckResponse.SERVING) + + +def _initialise_action_service(server: grpc.Server, + action_package_name: Union[str, types.ModuleType], + auto_reload: bool, + endpoints: str): + """Initialise the action service. + + Args: + server: The gRPC server. + action_package_name: Name of the package which contains the custom actions. + auto_reload: Enable auto-reloading of modules containing Action subclasses. + endpoints: Path to the endpoints file. + """ + executor = ActionExecutor() + executor.register_package(action_package_name) + tracer_provider = get_tracer_provider(endpoints) + action_webhook_pb2_grpc.add_ActionServiceServicer_to_server( + GRPCActionServerWebhook(executor, auto_reload, tracer_provider), server + ) async def run_grpc( @@ -247,15 +276,9 @@ async def run_grpc( futures.ThreadPoolExecutor(max_workers=workers), compression=grpc.Compression.Gzip, ) - initialise_interrupts(server) - executor = ActionExecutor() - executor.register_package(action_package_name) - tracer_provider = get_tracer_provider(endpoints) - action_webhook_pb2_grpc.add_ActionServiceServicer_to_server( - GRPCActionServerWebhook(executor, auto_reload, tracer_provider), server - ) - - _configure_health_server(server) + _initialise_interrupts(server) + _initialise_health_service(server) + _initialise_action_service(server, action_package_name, auto_reload, endpoints) ca_cert = file_as_bytes(ssl_ca_file) if ssl_ca_file else None