Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Question: How to cancel a running task #305

Open
realitix opened this issue Mar 21, 2024 · 4 comments
Open

Question: How to cancel a running task #305

realitix opened this issue Mar 21, 2024 · 4 comments

Comments

@realitix
Copy link

Hello,
I have a special case to manage and I don't see how to do it. At a given moment, I need to know whether a task (I have its ID) is actually in progress on a worker. Is that possible?

@realitix
Copy link
Author

After further consideration, what I am looking for is the ability to stop an ongoing task. Is it possible?

@s3rius
Copy link
Member

s3rius commented Apr 19, 2024

Currently there's no such functionality, but I really do want to define an interface to set up such task interrupters.

I'm open for discussion on that.

@realitix
Copy link
Author

I developed a custom receiver for that. If someone wants to do it with Redis, here is the code:

import asyncio
import uuid
from typing import Any, AsyncGenerator, cast

import anyio
from loguru import logger
from redis.asyncio import Redis
from taskiq.abc.broker import AckableMessage
from taskiq.message import BrokerMessage, TaskiqMessage
from taskiq.receiver.receiver import QUEUE_DONE, Receiver
from taskiq_redis import ListQueueBroker

# ruff: noqa: ANN401,BLE001,C901
# pylint: skip-file

# Keyword-argument name under which a cancel message carries the id of the
# task to cancel: written by the broker when it builds the cancel message and
# looked up by the receiver when it handles one.
CANCELLER_KEY = "__cancel_task_id__"


class CancellableListQueueBroker(ListQueueBroker):
    """List-queue broker that can additionally send and receive cancel requests.

    Cancel requests do not go through the task queue: they are published on a
    separate Redis pub/sub channel (``queue_name_cancel``) and consumed through
    :meth:`listen_canceller`.
    """

    def __init__(
        self,
        *args: Any,
        queue_name_cancel: str = "taskiq_cancel",
        **kwargs: Any,
    ) -> None:
        """Create the broker.

        :param args: positional arguments forwarded to ``ListQueueBroker``.
        :param queue_name_cancel: name of the pub/sub channel used for cancel
            requests.
        :param kwargs: keyword arguments forwarded to ``ListQueueBroker``.
        """
        super().__init__(*args, **kwargs)
        self.queue_name_cancel = queue_name_cancel

    async def listen_canceller(self) -> AsyncGenerator[bytes, None]:
        """Yield the raw payload of every cancel request published on the channel.

        Only pub/sub entries whose type is ``"message"`` are yielded; anything
        else (e.g. subscription confirmations) is logged at debug level and
        skipped.

        :yields: the ``data`` field of each received message.
        """
        # Entering the pub/sub object as a context manager guarantees that it
        # is closed (and its subscription dropped) when this generator is
        # closed or cancelled, instead of leaving it to garbage collection.
        async with Redis(
            connection_pool=self.connection_pool,
        ) as redis_conn, redis_conn.pubsub() as redis_pubsub_channel:
            await redis_pubsub_channel.subscribe(self.queue_name_cancel)
            async for message in redis_pubsub_channel.listen():
                if not message:
                    continue
                if message["type"] != "message":
                    logger.debug("Received non-message from redis: {}", message)
                    continue
                yield message["data"]

    async def cancel_task(self, task_id: uuid.UUID | str) -> None:
        """Publish a request to cancel the task identified by ``task_id``.

        :param task_id: id of the task to cancel, either as a ``uuid.UUID``
            (its ``hex`` form is sent) or as the task id string itself.
        """
        taskiq_message: TaskiqMessage = self._prepare_message(task_id)
        broker_message: BrokerMessage = self.formatter.dumps(taskiq_message)
        async with Redis(connection_pool=self.connection_pool) as redis_conn:
            await redis_conn.publish(self.queue_name_cancel, broker_message.message)

    def _prepare_message(self, task_id: uuid.UUID | str) -> TaskiqMessage:
        """Build the message that carries a cancel request.

        The id of the task to cancel travels in ``kwargs[CANCELLER_KEY]``; the
        message itself gets a fresh id from ``self.id_generator``.

        :param task_id: id of the task to cancel.
        :return: message ready to be serialized by the broker's formatter.
        """
        # The receiving side compares against the task id as a string, so a
        # UUID is reduced to its hex form and a string id is passed through.
        target_id = task_id.hex if isinstance(task_id, uuid.UUID) else str(task_id)
        return TaskiqMessage(
            task_id=self.id_generator(),
            task_name="canceller",
            labels={},
            labels_types={},
            args=[],
            kwargs={CANCELLER_KEY: target_id},
        )


class CancellableReceiver(Receiver):
    """Receiver that tracks its running tasks so they can be cancelled by id.

    Every executed message runs in an ``asyncio.Task`` named after the
    message's task id. A background coroutine consumes cancel requests from
    the broker and cancels the task whose name matches the requested id.
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the receiver; all arguments are forwarded to ``Receiver``."""
        super().__init__(*args, **kwargs)
        # Tasks currently being executed; each task's name is its task id.
        self.tasks: set[asyncio.Task[Any]] = set()

    def parse_message(self, message: bytes | AckableMessage) -> TaskiqMessage | None:
        """Decode a raw broker message into a ``TaskiqMessage``.

        :param message: raw bytes, or an ackable wrapper whose ``data`` is used.
        :return: the parsed message, or ``None`` if it could not be parsed
            (the failure is logged as a warning).
        """
        message_data = message.data if isinstance(message, AckableMessage) else message
        try:
            taskiq_msg = self.broker.formatter.loads(message=message_data)
            taskiq_msg.parse_labels()
        except Exception as exc:
            # loguru formats with ``{}`` placeholders and attaches tracebacks
            # through ``opt(exception=...)``; ``%s`` / ``exc_info`` are
            # stdlib-logging conventions and would be emitted verbatim.
            logger.opt(exception=exc).warning(
                "Cannot parse message: {}. Skipping execution.\n {}",
                message_data,
                exc,
            )
            return None
        return taskiq_msg

    async def listen(self) -> None:  # pragma: no cover
        """Run the prefetcher, the task runner and the cancel listener.

        Returns once the runner has finished; ``on_exit`` (if set) is then
        called with this receiver.
        """
        if self.run_startup:
            await self.broker.startup()
        logger.info("Listening started.")
        queue: asyncio.Queue[bytes | AckableMessage] = asyncio.Queue()

        async with anyio.create_task_group() as gr:

            async def run_until_done() -> None:
                await self.runner(queue)
                # The cancel listener never finishes on its own, so without
                # this the task group (and therefore ``listen``) would block
                # forever after the runner has consumed QUEUE_DONE.
                gr.cancel_scope.cancel()

            gr.start_soon(self.prefetcher, queue)
            gr.start_soon(run_until_done)
            gr.start_soon(self.runner_canceller)

        if self.on_exit is not None:
            self.on_exit(self)

    def _cancel_running_task(self, task_id: str) -> None:
        """Cancel every tracked task whose name equals ``task_id``."""
        # Iterate over a snapshot so the set can change underneath safely.
        for task in tuple(self.tasks):
            if task.get_name() != task_id:
                continue
            if task.cancel():
                logger.info("Cancelling task {}", task_id)
            else:
                logger.warning("Cannot cancel task {}", task_id)

    async def runner_canceller(
        self,
    ) -> None:
        """Consume cancel requests from the broker and cancel matching tasks.

        Runs until the broker's cancel stream ends or this coroutine is
        cancelled. Cancellation is deliberately not swallowed so that the
        enclosing task group / cancel scope sees it.
        """
        canceller_stream = cast(
            CancellableListQueueBroker,
            self.broker,
        ).listen_canceller()
        async for message in canceller_stream:
            taskiq_msg = self.parse_message(message)
            if taskiq_msg is None:
                continue
            if CANCELLER_KEY in taskiq_msg.kwargs:
                self._cancel_running_task(taskiq_msg.kwargs[CANCELLER_KEY])

    async def runner(
        self,
        queue: asyncio.Queue[bytes | AckableMessage],
    ) -> None:
        """Take messages from ``queue`` and execute each one in its own task.

        Each task is named after the message's task id and tracked in
        ``self.tasks`` until it completes. The loop stops when ``QUEUE_DONE``
        is received.

        :param queue: queue filled by the prefetcher.
        """

        def task_cb(task: asyncio.Task[Any]) -> None:
            # Stop tracking the finished task and free its concurrency slot.
            self.tasks.discard(task)
            if self.sem is not None:
                self.sem.release()

        while True:
            # Wait for a free execution slot when concurrency is limited.
            if self.sem is not None:
                await self.sem.acquire()

            self.sem_prefetch.release()
            message = await queue.get()
            if message is QUEUE_DONE:
                break

            taskiq_msg = self.parse_message(message)
            if taskiq_msg is None:
                # No task is started for an unparseable message, so task_cb
                # will never run for it: hand the slot back here, otherwise
                # every bad message would permanently shrink the concurrency.
                if self.sem is not None:
                    self.sem.release()
                continue

            task = asyncio.create_task(
                self.callback(message=message, raise_err=False),
                name=str(taskiq_msg.task_id),
            )
            self.tasks.add(task)
            task.add_done_callback(task_cb)

@realitix realitix changed the title Question: How to know that a task is actually running Question: How to cancel a running task Apr 20, 2024
@metheoryt
Copy link

That functionality would be really nice to have inside taskiq by default

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants