Skip to content

Commit

Permalink
Format code
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 18, 2024
1 parent b2839a3 commit fb8cd95
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ formatter:

lint:
poetry run ruff check rasa_sdk tests --ignore D
poetry run black --check rasa_sdk tests
poetry run black --exclude="rasa_sdk/grpc_py" --check rasa_sdk tests
make lint-docstrings

# Compare against `main` if no branch was provided
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
[tool.black]
line-length = 88
target-version = [ "py37", "py38", "py39", "py310",]
exclude = "((.eggs | .git | .mypy_cache | .pytest_cache | build | dist))"
exclude = "((.eggs | .git | .mypy_cache | .pytest_cache | build | dist ))"

[tool.poetry]
name = "rasa-sdk"
Expand Down Expand Up @@ -72,6 +72,7 @@ warn_unused_ignores = true
ignore = [ "D100", "D104", "D105", "RUF005",]
line-length = 88
select = [ "D", "E", "F", "W", "RUF",]
exclude = [ "rasa_sdk/grpc_py" ]

[tool.poetry.dependencies]
python = ">=3.8,<3.11"
Expand Down
4 changes: 1 addition & 3 deletions rasa_sdk/cli/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,5 @@ def add_endpoint_arguments(parser: argparse.ArgumentParser) -> None:
help="Configuration file for the assistant as a yml file.",
)
parser.add_argument(
"--grpc",
help="Starts grpc server instead of http",
action="store_true"
"--grpc", help="Starts grpc server instead of http", action="store_true"
)
5 changes: 4 additions & 1 deletion rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ async def actions(_) -> HTTPResponse:
if auto_reload:
executor.reload()

body = [action_name_item.model_dump() for action_name_item in executor.list_actions()] # noqa: E501
body = [
action_name_item.model_dump()
for action_name_item in executor.list_actions()
]
return response.json(body, status=200)

@app.exception(Exception)
Expand Down
2 changes: 2 additions & 0 deletions rasa_sdk/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def utter_image_url(self, image: Text, **kwargs: Any) -> None:

class ActionExecutor:
"""Executes actions."""

def __init__(self) -> None:
"""Initializes the `ActionExecutor`."""
self.actions: Dict[Text, Callable] = {}
Expand Down Expand Up @@ -516,4 +517,5 @@ def list_actions(self) -> List[ActionName]:

class ActionName(BaseModel):
"""Model for action name."""

name: str = Field(alias="name")
8 changes: 5 additions & 3 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,16 @@
)
from rasa_sdk.utils import (
check_version_compatibility,
number_of_sanic_workers, file_as_bytes,
number_of_sanic_workers,
file_as_bytes,
)

logger = logging.getLogger(__name__)


class GRPCActionServerHealthCheck(health_pb2_grpc.HealthServiceServicer):
"""Runs health check RPC which is served through gRPC server."""

def __init__(self) -> None:
"""Initializes the HealthServicer."""
pass
Expand Down Expand Up @@ -232,8 +234,8 @@ async def run_grpc(
f"[::]:{port}",
server_credentials=grpc.ssl_server_credentials(
private_key_certificate_chain_pairs=[(private_key, certificate_chain)],
root_certificates = ca_cert,
require_client_auth = True if ca_cert else False,
root_certificates=ca_cert,
require_client_auth=True if ca_cert else False,
),
)
else:
Expand Down
21 changes: 8 additions & 13 deletions rasa_sdk/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class Tracker:
@classmethod
def from_dict(cls, state: "TrackerState") -> "Tracker":
"""Create a tracker from dump."""

return Tracker(
state["sender_id"],
state.get("slots", {}),
Expand All @@ -49,7 +48,6 @@ def __init__(
stack: Optional[List[Dict[Text, Any]]] = None,
) -> None:
"""Initialize the tracker."""

# list of previously seen events
self.events = events
# id of the source of the messages
Expand All @@ -72,6 +70,7 @@ def __init__(

@property
def active_form(self) -> Dict[Text, Any]:
"""Get the currently active form."""
warnings.warn(
"Use of `active_form` is deprecated. Please use `active_loop insteaad.",
DeprecationWarning,
Expand All @@ -80,7 +79,6 @@ def active_form(self) -> Dict[Text, Any]:

def current_state(self) -> Dict[Text, Any]:
"""Return the current tracker state as an object."""

if len(self.events) > 0:
latest_event_time = self.events[-1].get("timestamp")
else:
Expand All @@ -100,12 +98,11 @@ def current_state(self) -> Dict[Text, Any]:
}

def current_slot_values(self) -> Dict[Text, Any]:
"""Return the currently set values of the slots"""
"""Return the currently set values of the slots."""
return self.slots

def get_slot(self, key) -> Optional[Any]:
"""Retrieves the value of a slot."""

if key in self.slots:
return self.slots[key]
else:
Expand Down Expand Up @@ -133,7 +130,6 @@ def get_latest_entity_values(
Returns:
List of entity values.
"""

entities = self.latest_message.get("entities", [])
return (
x.get("value")
Expand All @@ -144,8 +140,7 @@ def get_latest_entity_values(
)

def get_latest_input_channel(self) -> Optional[Text]:
"""Get the name of the input_channel of the latest UserUttered event"""

"""Get the name of the input_channel of the latest UserUttered event."""
for e in reversed(self.events):
if e.get("event") == "user":
return e.get("input_channel")
Expand Down Expand Up @@ -229,7 +224,8 @@ def applied_events(self) -> List[Dict[Text, Any]]:

def undo_till_previous(event_type: Text, done_events: List[Dict[Text, Any]]):
"""Removes events from `done_events` until the first
occurrence `event_type` is found which is also removed."""
occurrence `event_type` is found which is also removed.
"""
# list gets modified - hence we need to copy events!
for e in reversed(done_events[:]):
del done_events[-1]
Expand Down Expand Up @@ -262,7 +258,6 @@ def slots_to_validate(self) -> Dict[Text, Any]:
Returns:
A mapping of extracted slot candidates and their values.
"""

slots: Dict[Text, Any] = {}
count: int = 0

Expand Down Expand Up @@ -331,7 +326,6 @@ class Action:

def name(self) -> Text:
"""Unique identifier of this simple action."""

raise NotImplementedError("An action must implement a name")

async def run(
Expand All @@ -356,7 +350,6 @@ async def run(
A dictionary of `rasa_sdk.events.Event` instances that is
returned through the endpoint
"""

raise NotImplementedError("An action must implement its run method")

def __str__(self) -> Text:
Expand All @@ -365,7 +358,9 @@ def __str__(self) -> Text:

class ActionExecutionRejection(Exception):
"""Raising this exception will allow other policies
to predict another action"""
to predict another action
.
"""

def __init__(self, action_name: Text, message: Optional[Text] = None) -> None:
self.action_name = action_name
Expand Down
2 changes: 2 additions & 0 deletions rasa_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

class Element(dict):
"""Represents an element in a list of elements in a rich message."""

__acceptable_keys = ["title", "item_url", "image_url", "subtitle", "buttons"]

def __init__(self, *args, **kwargs):
Expand All @@ -46,6 +47,7 @@ def __init__(self, *args, **kwargs):

class Button(dict):
"""Represents a button in a rich message."""

pass


Expand Down

0 comments on commit fb8cd95

Please sign in to comment.