Messenger Interface rework #357

Open: wants to merge 55 commits into base: dev
Changes from 42 commits (55 commits total)

Commits:
70306f4
poetry.lock update
ZergLev May 6, 2024
af90b04
trying graceful termination
ZergLev May 6, 2024
9d7ecf0
graceful termination is done exclusively within the interface class (…
ZergLev May 8, 2024
0c377f1
mistake fixed
ZergLev May 8, 2024
8360976
sigint handling moved to pipeline + custom stop funcs added
ZergLev May 13, 2024
8476b16
feature works now
ZergLev May 13, 2024
4d83016
formatted code
ZergLev May 13, 2024
2e0dedc
fixed typo
ZergLev May 13, 2024
93b5359
interfaces enhanced with asyncio, changes to graceful termination added
ZergLev May 16, 2024
d759fba
added a check to signal handler(unfinished)
ZergLev May 16, 2024
01b0002
formatted with poetry
ZergLev May 16, 2024
52eb5bf
PollingMessengerInterface() overhaul draft
ZergLev May 22, 2024
88753a8
formatted with poetry
ZergLev May 22, 2024
5f98eda
refactor
ZergLev May 22, 2024
591bd8f
testing, but connect() doesn't work
ZergLev May 23, 2024
fd89b12
writing tests
ZergLev May 24, 2024
98a13b4
old unit-tests work with this now, a few mistakes fixed
ZergLev May 27, 2024
9889199
telegram bot works, but graceful termination apparently does not
ZergLev May 27, 2024
e956cc3
Trying the echo test
ZergLev Jun 16, 2024
e238432
Trying the echo test
ZergLev Jun 16, 2024
1f88e47
Trying the echo test
ZergLev Jun 16, 2024
97cec27
echo test draft (doesn't launch)
ZergLev Jun 16, 2024
e601ed6
first test works, several bug fixes
ZergLev Jun 19, 2024
7d8c68c
ContextLock() test added
ZergLev Jun 19, 2024
8ebf6ec
comments changed
ZergLev Jun 19, 2024
e366a61
debug output removed
ZergLev Jun 19, 2024
09fe10f
confusing comment removed
ZergLev Jun 19, 2024
0cecb80
comment changes
ZergLev Jun 19, 2024
1c24131
new tests moved to a separate file
ZergLev Jun 20, 2024
b49dac7
more tests added
ZergLev Jun 20, 2024
876ce8d
poll_timeout added + test changed
ZergLev Jun 21, 2024
f91910b
typo corrected
ZergLev Jun 21, 2024
585a37d
add siginthandler to async loop
RLKRo Jun 21, 2024
4180e17
fix test class
RLKRo Jun 21, 2024
0059bd3
adding worker timeouts and cleanup
ZergLev Jun 26, 2024
c3f18a7
Merge branch 'feat/graceful_termination' of https://github.com/ZergLe…
ZergLev Jun 26, 2024
65329f8
new _worker() seems to be working (it's awaited)
ZergLev Jun 26, 2024
30369f6
all tests but one working, can't call shutdown()
ZergLev Jun 26, 2024
ce9ac81
all tests working
ZergLev Jun 26, 2024
c765c18
ContextLock() moved to pipeline.py
ZergLev Jul 1, 2024
83ebe7f
formatted with poetry
ZergLev Jul 1, 2024
b99c4eb
Merge branch 'dev' into feat/graceful_termination
RLKRo Aug 2, 2024
8715c2f
Update tests/messengers/common/test_messenger_interface.py
ZergLev Aug 2, 2024
a829cf5
Update chatsky/messengers/common/interface.py
ZergLev Aug 5, 2024
16b049a
review changes started, bugs appeared
ZergLev Aug 7, 2024
a8607c6
Merge branch 'feat/graceful_termination' of https://github.com/ZergLe…
ZergLev Aug 7, 2024
5656437
moving graceful termination to pipeline, windows support added back i…
ZergLev Aug 16, 2024
cd16255
in the process of fixing bugs, docs partially added
ZergLev Aug 16, 2024
3282d18
new LongpollingMessengerInterface drafted + removing run_in_foregroun…
ZergLev Aug 21, 2024
b2140a5
Merge branch 'dev' into feat/graceful_termination
ZergLev Aug 21, 2024
6fde6c3
lint
ZergLev Aug 21, 2024
2529537
Merge branch 'feat/graceful_termination' of https://github.com/ZergLe…
ZergLev Aug 21, 2024
6334b8e
lint
ZergLev Aug 21, 2024
9a2381a
fully removed run_in_foreground, some changes to graceful termination
ZergLev Aug 23, 2024
b8c99d9
in the process of switching to BaseModel + other changes
ZergLev Aug 23, 2024
209 changes: 167 additions & 42 deletions chatsky/messengers/common/interface.py
@@ -11,13 +11,19 @@
import logging
from pathlib import Path
from tempfile import gettempdir
import signal
from functools import partial
import time # Don't forget to remove this
import contextlib

from typing import Optional, Any, List, Tuple, Hashable, TYPE_CHECKING, Type

if TYPE_CHECKING:
from chatsky.script import Context, Message
from chatsky.pipeline.types import PipelineRunnerFunction
from chatsky.messengers.common.types import PollingInterfaceLoopFunction
from chatsky.script.core.message import Attachment
from chatsky.pipeline.pipeline.pipeline import Pipeline

logger = logging.getLogger(__name__)

@@ -28,8 +34,15 @@ class MessengerInterface(abc.ABC):
It is responsible for connection between user and pipeline, as well as for request-response transactions.
"""

def __init__(self):
self.task = None
self.running_in_foreground = False
self.running = True
self.stopped = False
self.shielded = False # This determines whether the interface wants to be shut down with task.cancel() or just switching a flag. Let's say PollingMessengerInterface wants task.cancel()

@abc.abstractmethod
async def connect(self, pipeline_runner: PipelineRunnerFunction):
async def connect(self, *args):
"""
Method invoked when message interface is instantiated and connection is established.
May be used for sending an introduction message or displaying general bot information.
@@ -39,6 +52,50 @@ async def connect(self, pipeline_runner: PipelineRunnerFunction):
"""
raise NotImplementedError

# This is an optional method, so no need to make it abstract, I think.
async def cleanup(self, *args):
pass

async def run_in_foreground(
self, pipeline: Pipeline, loop: PollingInterfaceLoopFunction = lambda: True, timeout: float = 0, *args
):
Member:

Do we need this method?
Maybe the pipeline should create a task from the interface and await it on its own?
We are going to add support for multiple messenger interfaces, and I think the API for using one or multiple interfaces shouldn't differ.

Collaborator Author (@ZergLev, Aug 21, 2024):

I'm for removing it, but I'm not sure how to do this properly. Say this code is moved to pipeline.py, where connect() is called from a_run(). Then tests will want to call connect() directly if they want to pass parameters, which means all the code in pipeline.a_run() would be bypassed. Basically, we aren't giving the user an option to both run the code from pipeline.run() and pass their own parameters into connect().
What if all those parameters were passed as MessengerInterface constructor parameters? I'm not sure, though; these don't really look like constructor parameters, and it seems a bit off. If we do go through with this, Pydantic's BaseModel may be useful.
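
To make the alternative concrete, here is a rough sketch (not part of this diff; all names are hypothetical) of how polling parameters could live on the interface itself, so that the pipeline creates and awaits one task per interface and one or many interfaces look the same:

import asyncio


class SketchPollingInterface:
    def __init__(self, poll_timeout: float = 10.0, number_of_workers: int = 2):
        # Parameters previously passed to connect() become constructor fields.
        self.poll_timeout = poll_timeout
        self.number_of_workers = number_of_workers

    async def connect(self, pipeline_runner):
        ...  # the polling loop would live here, reading self.poll_timeout etc.


class SketchPipeline:
    def __init__(self, interfaces):
        self.interfaces = interfaces

    async def _run_pipeline(self, request, ctx_id=None):
        ...

    async def a_run(self):
        # The pipeline owns the interface tasks, so multi-interface support
        # and shutdown can be handled in one place.
        tasks = [asyncio.create_task(iface.connect(self._run_pipeline)) for iface in self.interfaces]
        await asyncio.gather(*tasks)
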

self.running_in_foreground = True
self.pipeline = pipeline

async_loop = asyncio.get_running_loop()
async_loop.add_signal_handler(signal.SIGINT, partial(pipeline.sigint_handler, async_loop))
# TO-DO: Clean this up and/or think this through (connect() methods are different for various MessengerInterface() classes)
if isinstance(self.pipeline.messenger_interface, PollingMessengerInterface):
self.task = asyncio.create_task(self.connect(loop=loop, timeout=timeout, *args))
elif isinstance(self.pipeline.messenger_interface, CallbackMessengerInterface):
self.task = asyncio.create_task(self.connect(self.pipeline._run_pipeline, *args))
else:
self.task = asyncio.create_task(self.connect(self.pipeline._run_pipeline, *args))

try:
await self.task
except asyncio.CancelledError:
await asyncio.sleep(0)
Member:

Add a comment on why we need to sleep here.

Collaborator Author (@ZergLev):

I'm not sure why myself. I think it's there so that self.running can be set to False from the shutdown() method, ending the polling_loop. I'll try changing that line to self.running = False and see what happens.
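
As an aside (my reading, not stated in the thread): await asyncio.sleep(0) simply yields control to the event loop once, so any already-scheduled task, such as a shutdown() coroutine flipping self.running, gets a chance to run before execution continues here. A tiny self-contained illustration:

import asyncio


async def flip_flag(state: dict):
    state["running"] = False


async def main():
    state = {"running": True}
    asyncio.get_running_loop().create_task(flip_flag(state))
    await asyncio.sleep(0)  # yield once so the task scheduled above can run
    print(state["running"])  # prints False


asyncio.run(main())
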

await self.cleanup()

self.stopped = True

# Placeholder for any cleanup code.

# I can make shutdown() work for PollingMessengerInterface, but I don't know the structure of Telegram Messenger Interfaces. Right now, this ends the main task and sets a flag self.running to False, so that any async tasks in loops can see that and turn off as soon as they are done.
async def shutdown(self):
logger.info(f"messenger_interface.shutdown() called - shutting down interface")
self.running = False
self.task.cancel()
try:
await self.task
except asyncio.CancelledError:
# raise asyncio.CancelledError
# await asyncio.sleep(0)
if not self.stopped:
raise asyncio.CancelledError
logger.info(f"{type(self).__name__} has stopped working - SIGINT received")


class MessengerInterfaceWithAttachments(MessengerInterface, abc.ABC):
"""
@@ -94,74 +151,142 @@ class PollingMessengerInterface(MessengerInterface):
Polling message interface runs in a loop, constantly asking users for a new input.
"""

def __init__(self):
self.request_queue = asyncio.Queue()
self.cancel_on_shutdown = True # Would like task.cancel(). (Not done yet)
self.number_of_workers = 2
# Could make this an argument of connect(), but people can just type interface.number_of_workers = their_number before creating pipeline. Interface features like timeouts could be a tutorial, actually. But it's not really necessary or in demand.
Member:

Maybe make Messenger Interfaces BaseModel to make handling these flags easier?
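
A possible shape for that, purely as an illustration (field names and defaults are assumed, not taken from this PR):

from pydantic import BaseModel


class SketchPollingInterfaceSettings(BaseModel):
    # Flags and tuning knobs become validated, documented fields
    # instead of ad-hoc attributes assigned in __init__.
    running: bool = True
    stopped: bool = False
    shielded: bool = False
    number_of_workers: int = 2
    poll_timeout: float = 10.0
    worker_timeout: float = 60.0

The interface could then either inherit from such a model or hold one as a settings attribute.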

self._worker_tasks = []
super().__init__()

@abc.abstractmethod
def _request(self) -> List[Tuple[Message, Hashable]]:
async def _respond(self, ctx_id, last_response):
"""
Method used for sending users request for their input.
Method used for sending users responses for their last input.

:return: A list of tuples: user inputs and context ids (any user ids) associated with the inputs.
:param ctx_id: Context id, specifies the user id. Without multiple messenger interfaces it's basically a redundant parameter, because this function is just a more complex `print(last_response)`. (Change before merge)
:param last_response: Latest response from the pipeline which should be relayed to the specified user.
"""
raise NotImplementedError

@abc.abstractmethod
def _respond(self, responses: List[Context]):
async def _process_request(self, ctx_id, update: Message, pipeline: Pipeline):
"""
Method used for sending users responses for their last input.
Process a new update for ctx.
"""
context = await pipeline._run_pipeline(update, ctx_id)
await self._respond(ctx_id, context.last_response)

:param responses: A list of contexts, representing dialogs with the users;
`last_response`, `id` and some dialog info can be extracted from there.
async def _worker_job(self):
"""
raise NotImplementedError
Obtain Lock over the current context,
Process the update and send it.
"""
request = await self.request_queue.get()
if request is not None:
(ctx_id, update) = request
async with self.pipeline.context_lock[ctx_id]: # get exclusive access to this context among interfaces
# Trying to see if _process_request works at all. Looks like it does it just fine, actually
# await self._process_request(ctx_id, update, self.pipeline)
# Doesn't work in a thread for some reason - it goes into an infinite cycle.
# """
await asyncio.to_thread( # [optional] execute in a separate thread to avoid blocking
self._process_request, ctx_id, update, self.pipeline
)
Member:

Does to_thread work?
Clean up this code.

Collaborator Author (@ZergLev):

It doesn't. It just goes into an infinite cycle with no outputs. I guess I could look through the log file, but I haven't done that yet.
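
One likely contributor, offered as a guess rather than a confirmed diagnosis: _process_request is an async def, while asyncio.to_thread expects a regular blocking callable. Handing it a coroutine function only creates a coroutine object inside the worker thread and never awaits it, so the pipeline run never actually happens. A sketch of the two usual patterns (some_blocking_function below is a placeholder, not a real helper):

# Either await the coroutine directly on the event loop (as in the commented-out line above):
await self._process_request(ctx_id, update, self.pipeline)
# ...or, if a separate thread is truly needed, hand to_thread a synchronous callable:
result = await asyncio.to_thread(some_blocking_function, ctx_id, update)
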

# """
return False
else:
return True

def _on_exception(self, e: BaseException):
# This worker doesn't save the request and basically deletes it from the queue if it can't process it. An option to save the request may be fitting, perhaps with a limited number of retries.
async def _worker(self, worker_timeout: float):
while self.running or not self.request_queue.empty():
try:
no_more_jobs = await asyncio.wait_for(self._worker_job(), timeout=worker_timeout)
if no_more_jobs:
logger.info(
f"Worker finished working - stop signal received and remaining requests have been processed."
)
# This logging is incorrect right now, request queue running out isn't handled and it's mistakenly called a stop signal.
break
except TimeoutError:
# If there are simply no requests coming in, the worker will keep emitting this log message,
# which looks really bad.
logger.info("worker just timed out. A request *may* have been lost.")

@abc.abstractmethod
async def _get_updates(self) -> list[tuple[Any, Message]]:
"""
Method that is called on polling cycle exceptions, in some cases it should show users the exception.
By default, it logs all exit exceptions to `info` log and all non-exit exceptions to `error`.
Obtain updates from another server

:param e: The exception.
Example:
self.bot.request_updates()
"""
if isinstance(e, Exception):
logger.error(f"Exception in {type(self).__name__} loop!", exc_info=e)
else:
logger.info(f"{type(self).__name__} has stopped polling.")

async def _polling_job(self, poll_timeout: float):
try:
coroutine = asyncio.wait_for(self._get_updates(), timeout=poll_timeout)
received_updates = await coroutine
if received_updates is not None:
for update in received_updates:
await self.request_queue.put(update)
except TimeoutError:
# self.shutdown()
# Shutting down is probably too extreme, unless it's several times in a row maybe.
logger.info("polling_job failed - timed out")

async def _polling_loop(
self,
pipeline_runner: PipelineRunnerFunction,
loop: PollingInterfaceLoopFunction = lambda: True,
poll_timeout: float = None,
timeout: float = 0,
):
"""
Method running the request - response cycle once.
"""
user_updates = self._request()
responses = [await pipeline_runner(request, ctx_id) for request, ctx_id in user_updates]
self._respond(responses)
await asyncio.sleep(timeout)
try:
while loop() and self.running:
await asyncio.shield(self._polling_job(poll_timeout)) # shield from cancellation
await asyncio.sleep(timeout)
finally:
self.running = False
print("loop ending")
logger.info(
f"polling_loop stopped working - either the stop signal was received or the loop() condition was false."
)
# When there are no more jobs or a stop signal is received, a special 'None' request is sent to the queue
# (one for each worker, matching self.number_of_workers) to shut the workers down.
for i in range(self.number_of_workers):
self.request_queue.put_nowait(None)

async def connect(
self,
pipeline_runner: PipelineRunnerFunction,
loop: PollingInterfaceLoopFunction = lambda: True,
poll_timeout: float = None,
worker_timeout: float = None,
timeout: float = 0,
):
"""
Method, running a request - response cycle in a loop.
The looping behavior is determined by `loop` and `timeout`,
for most cases the loop itself shouldn't be overridden.
# Saving strong references to workers, so that they can be cleaned up properly.
# shield() creates a task just like create_task()
for i in range(self.number_of_workers):
task = asyncio.shield(self._worker(worker_timeout))
self._worker_tasks.append(task)
await self._polling_loop(loop=loop, poll_timeout=poll_timeout, timeout=timeout)

# Workers for PollingMessengerInterface are awaited here.
async def cleanup(self):
await super().cleanup()
await asyncio.wait(self._worker_tasks)
# await asyncio.gather(*self._worker_tasks)
# Blocks until all workers are done

:param pipeline_runner: A function that should process user request and return context;
usually it's a :py:meth:`~chatsky.pipeline.pipeline.pipeline.Pipeline._run_pipeline` function.
:param loop: a function that determines whether polling should be continued;
called in each cycle, should return `True` to continue polling or `False` to stop.
:param timeout: a time interval between polls (in seconds).
def _on_exception(self, e: BaseException):
"""
while loop():
try:
await self._polling_loop(pipeline_runner, timeout)
Method that is called on polling cycle exceptions, in some cases it should show users the exception.
By default, it logs all exit exceptions to `info` log and all non-exit exceptions to `error`.

except BaseException as e:
self._on_exception(e)
break
:param e: The exception.
"""
if isinstance(e, Exception):
logger.error(f"Exception in {type(self).__name__} loop!", exc_info=e)
else:
logger.info(f"{type(self).__name__} has stopped polling.")


class CallbackMessengerInterface(MessengerInterface):
17 changes: 6 additions & 11 deletions chatsky/messengers/console.py
@@ -1,8 +1,6 @@
from typing import Any, Hashable, List, Optional, TextIO, Tuple
from uuid import uuid4
from chatsky.messengers.common.interface import PollingMessengerInterface
from chatsky.pipeline.types import PipelineRunnerFunction
from chatsky.script.core.context import Context
from chatsky.script.core.message import Message


@@ -12,9 +10,6 @@ class CLIMessengerInterface(PollingMessengerInterface):
This message interface can maintain dialog with one user at a time only.
"""

supported_request_attachment_types = set()
supported_response_attachment_types = set()

def __init__(
self,
intro: Optional[str] = None,
@@ -29,13 +24,13 @@ def __init__(
self._prompt_response: str = prompt_response
self._descriptor: Optional[TextIO] = out_descriptor

def _request(self) -> List[Tuple[Message, Any]]:
return [(Message(input(self._prompt_request)), self._ctx_id)]
async def _get_updates(self) -> List[Tuple[Any, Message]]:
return [(self._ctx_id, Message(input(self._prompt_request)))]

def _respond(self, responses: List[Context]):
print(f"{self._prompt_response}{responses[0].last_response.text}", file=self._descriptor)
async def _respond(self, ctx_id, last_response: Message):
print(f"{self._prompt_response}{last_response.text}", file=self._descriptor)

async def connect(self, pipeline_runner: PipelineRunnerFunction, **kwargs):
async def connect(self, *args, **kwargs):
"""
The CLIProvider generates new dialog id used to user identification on each `connect` call.

@@ -46,4 +41,4 @@
self._ctx_id = uuid4()
if self._intro is not None:
print(self._intro)
await super().connect(pipeline_runner, **kwargs)
await super().connect(*args, **kwargs)
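
As a side note, the console changes above illustrate the new polling contract: subclasses implement an async _get_updates() that returns (ctx_id, Message) pairs, and an async _respond(ctx_id, last_response). A minimal hypothetical subclass for some other transport might look like this (self.client and its methods are assumed placeholders, not a real API):

from typing import Any, List, Tuple

from chatsky.messengers.common.interface import PollingMessengerInterface
from chatsky.script.core.message import Message


class SketchHTTPMessengerInterface(PollingMessengerInterface):
    def __init__(self, client):
        self.client = client  # hypothetical HTTP client
        super().__init__()

    async def _get_updates(self) -> List[Tuple[Any, Message]]:
        # Fetch pending messages and convert them into (ctx_id, Message) pairs.
        raw_updates = await self.client.fetch_new_messages()
        return [(item["user_id"], Message(item["text"])) for item in raw_updates]

    async def _respond(self, ctx_id, last_response: Message):
        # Relay the pipeline's latest response back to the user.
        await self.client.send_message(ctx_id, last_response.text)
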
36 changes: 35 additions & 1 deletion chatsky/pipeline/pipeline/pipeline.py
@@ -15,6 +15,7 @@
"""

import asyncio
import signal
import logging
from typing import Union, List, Dict, Optional, Hashable, Callable

@@ -102,6 +103,8 @@
parallelize_processing: bool = False,
):
self.actor: Actor = None
self.stopped_by_signal = False
self.context_lock = ContextLock()
self.messenger_interface = CLIMessengerInterface() if messenger_interface is None else messenger_interface
self.context_storage = {} if context_storage is None else context_storage
self.slots = GroupSlot.model_validate(slots) if slots is not None else None
@@ -347,6 +350,15 @@

return ctx

def sigint_handler(self, loop):
self.stopped_by_signal = True
print("_sigint_handler() called")
# asyncio.run(asyncio.gather(*[iface.shutdown() for iface in self.messenger_interfaces]))
if self.messenger_interface.running_in_foreground:
loop.run_until_complete(self.messenger_interface.shutdown())
# In case someone launched a pipeline with connect() instead of run_in_foreground(), all SIGINTs will be ignored, though the flag self.stopped_by_signal is still changed to True.
logger.info(f"pipeline received SIGINT - stopping pipeline and all interfaces")

def run(self):
"""
Method that starts a pipeline and connects to `messenger_interface`.
@@ -355,7 +367,18 @@ def run(self):
This method can be both blocking and non-blocking. It depends on current `messenger_interface` nature.
Message interfaces that run in a loop block current thread.
"""
asyncio.run(self.messenger_interface.connect(self._run_pipeline))

# event_loop = asyncio.get_event_loop()
# event_loop.add_signal_handler(signal.SIGINT, self._sigint_handler)

# This doesn't work for now, because _sigint_handler is just added to the end of the queue of async tasks, so it ends up waiting for the very program it is supposed to shut down.
# I'm using a different solution for now, but the original one has the benefit of utilising the event loop (not ending other asyncio tasks) and "being thread-safe" according to some sources; not sure if that's true or needed, though.
# TO-DO: Do graceful termination via the event loop. I'm thinking that if the _sigint_handler() task could be added to the start of the asyncio queue rather than the end, it would have worked, but I don't know whether that will work or how to do it.

# signal.signal(signal.SIGINT, self.sigint_handler)

asyncio.run(self.messenger_interface.run_in_foreground(self, self._run_pipeline))
logger.info(f"pipeline finished working")
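
Regarding the TO-DO above, one option (a sketch only, not verified against this code base; loop.add_signal_handler is also Unix-only): the handler passed to add_signal_handler must be a plain synchronous callback, but it can schedule a coroutine with loop.create_task, so the shutdown coroutine runs on the same event loop instead of waiting behind the rest of the program:

import asyncio
import signal


async def sketch_run(pipeline):
    loop = asyncio.get_running_loop()
    # Synchronous callback that schedules the async shutdown on the running loop.
    loop.add_signal_handler(
        signal.SIGINT,
        lambda: loop.create_task(pipeline.messenger_interface.shutdown()),
    )
    await pipeline.messenger_interface.connect(pipeline._run_pipeline)
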

def __call__(
self, request: Message, ctx_id: Optional[Hashable] = None, update_ctx_misc: Optional[dict] = None
@@ -372,3 +395,14 @@ def __call__(
@property
def script(self) -> Script:
return self.actor.script


class ContextLock:
# locks: dict[ctx_id, asyncio.Lock] = {}
def __init__(self):
self.locks = {}

def __getitem__(self, key):
if key not in self.locks:
self.locks[key] = asyncio.Lock()
return self.locks[key]
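
For illustration, assumed usage mirroring _worker_job above: ContextLock hands out one asyncio.Lock per context id, so concurrent updates for the same user are serialized while different users are processed in parallel.

import asyncio

context_lock = ContextLock()


async def handle_update(ctx_id, update):
    async with context_lock[ctx_id]:  # same ctx_id -> same Lock -> serialized
        ...  # run the pipeline for this update
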