Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

How to cancel sending a task using middleware #327

Open
Bohdan-Ilchyshyn opened this issue May 20, 2024 · 0 comments
Open

How to cancel sending a task using middleware #327

Bohdan-Ilchyshyn opened this issue May 20, 2024 · 0 comments

Comments

@Bohdan-Ilchyshyn
Copy link

Bohdan-Ilchyshyn commented May 20, 2024

I created a singleton middleware.
It checks whether the same task already exists and, if so, should cancel sending it in the `pre_send` function.
What is the correct way to do this: return `None` or raise an exception?

Middleware code

import inspect
import time
from hashlib import md5
from typing import Any, Coroutine, Union

from cashews import cache
from loguru import logger
from orjson import orjson
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult
from taskiq.exceptions import SendTaskError


class SingletonMiddleware(TaskiqMiddleware):
    """Allow at most one pending instance of a task per set of arguments.

    Tasks opt in with the ``singleton`` label. In ``pre_send`` a lock is
    taken in the cashews cache. Its key is derived from the task name and
    its arguments, and it stores the sending message's task id. If the lock
    is already held, sending is cancelled by raising ``SendTaskError``. The
    lock is released in ``post_execute`` / ``on_error``, or when it expires.

    Labels read from the task:

    * ``singleton``   -- enables this middleware for the task; a falsy value
      such as ``False`` leaves it disabled.
    * ``unique_on``   -- a parameter name, or list of names, that decide what
      counts as a duplicate; by default all arguments are used.
    * ``lock_expire`` -- lock lifetime in seconds.
    * ``timeout``     -- when ``lock_expire`` is absent, the lock lives for the
      task timeout plus ``TIMEOUT_LOCK_MARGIN`` seconds.
    """

    SINGLETON_LABEL = "singleton"
    UNIQUE_ON_LABEL = "unique_on"
    LOCK_EXPIRE_LABEL = "lock_expire"
    TIMEOUT_LABEL = "timeout"
    KEY_PREFIX = "TKQ_SINGLETON_LOCK_"
    # Extra seconds added on top of a task's timeout so the lock outlives it.
    TIMEOUT_LOCK_MARGIN = 5 * 60
    # ``singleton`` label values that mean "off". The string spellings are
    # listed because label values may be stringified in transit
    # (assumption -- confirm for the taskiq version in use).
    _SINGLETON_DISABLED_VALUES = (None, False, "", "0", "false", "False")

    def __init__(
        self,
        default_lock_expire: int = 60,
    ) -> None:
        """Create the middleware.

        :param default_lock_expire: lock lifetime in seconds, used when a task
            has neither a ``lock_expire`` nor a ``timeout`` label.
        """
        super().__init__()
        self.default_lock_expire = default_lock_expire

    def pre_send(
        self,
        message: "TaskiqMessage",
    ) -> "Union[TaskiqMessage, Coroutine[Any, Any, TaskiqMessage]]":
        """Acquire the singleton lock before a singleton task is sent.

        Non-singleton messages are returned unchanged. For singleton tasks a
        coroutine is returned, which either yields the message or raises
        ``SendTaskError`` when a duplicate is already holding the lock.
        """
        if self.is_singleton_task(message):
            return self.lock_and_run(message)
        return message

    async def post_execute(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock once the task has run."""
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    async def on_error(
        self,
        message: "TaskiqMessage",
        result: "TaskiqResult[Any]",
        exception: BaseException,
    ) -> "Union[None, Coroutine[Any, Any, None]]":
        """Release the singleton lock when the task raised.

        NOTE(review): if taskiq also calls ``post_execute`` for failed tasks,
        the lock is released twice. ``unlock`` is passed this message's
        task id, presumably so a second call cannot free someone else's
        lock -- confirm against cashews.
        """
        if self.is_singleton_task(message):
            await self.release_lock(message)
        return None

    def is_singleton_task(self, message: "TaskiqMessage") -> bool:
        """Return True if the task opted in with a truthy ``singleton`` label."""
        if self.SINGLETON_LABEL not in message.labels:
            return False
        value = message.labels[self.SINGLETON_LABEL]
        return value not in self._SINGLETON_DISABLED_VALUES

    @staticmethod
    async def unlock(lock_key: str, task_id: str) -> bool:
        """Release ``lock_key`` via cashews, passing the owning ``task_id``."""
        return await cache.unlock(lock_key, task_id)

    @staticmethod
    async def lock(lock_key: str, task_id: str, expire: int) -> bool:
        """Try to take ``lock_key`` for ``expire`` seconds, storing ``task_id``.

        Returns the result of ``cache.set_lock`` (truthy when the lock was
        acquired, as used by ``lock_and_run``).
        """
        return await cache.set_lock(key=lock_key, value=task_id, expire=expire)

    @staticmethod
    async def locked(lock_key: str) -> bool:
        """Return whether ``lock_key`` is currently locked."""
        return await cache.is_locked(key=lock_key)

    @staticmethod
    async def get_existing_task_id(lock_key: str) -> "str | None":
        """Return the task id stored under ``lock_key``.

        Presumably ``None`` if the lock has expired meanwhile -- confirm with
        cashews ``cache.get``.
        """
        return await cache.get(key=lock_key)

    async def lock_and_run(self, message: TaskiqMessage) -> TaskiqMessage:
        """Acquire the singleton lock for *message* or cancel sending it.

        :returns: *message* unchanged when the lock was acquired.
        :raises SendTaskError: when another instance already holds the lock.
            ``pre_send`` must hand back a ``TaskiqMessage`` (see its return
            annotation), so returning ``None`` is not a valid way to cancel.
            Raising aborts the send instead.
        """
        if await self.acquire_lock(message):
            return message

        lock_key = self.generate_lock(message)
        # Must be awaited; otherwise the log would show a coroutine object.
        existing_task_id = await self.get_existing_task_id(lock_key)
        logger.warning(f"Attempted to queue a duplicate of task ID {existing_task_id}")
        raise SendTaskError()

    @staticmethod
    def _to_seconds(value: Any) -> int:
        """Convert a label value such as ``30``, ``"30"`` or ``"30.5"`` to whole seconds."""
        return int(float(value))

    async def get_lock_expire(self, message: "TaskiqMessage") -> int:
        """Return the lock lifetime in seconds for *message*.

        The ``lock_expire`` label wins. Otherwise the ``timeout`` label plus
        ``TIMEOUT_LOCK_MARGIN`` is used, and failing both,
        ``default_lock_expire``.
        """
        labels = message.labels
        if self.LOCK_EXPIRE_LABEL in labels:
            return self._to_seconds(labels[self.LOCK_EXPIRE_LABEL])
        if self.TIMEOUT_LABEL in labels:
            return self._to_seconds(labels[self.TIMEOUT_LABEL]) + self.TIMEOUT_LOCK_MARGIN
        return self.default_lock_expire

    async def release_lock(self, message: "TaskiqMessage") -> bool:
        """Release the lock for *message*, identified by its task id."""
        lock_key = self.generate_lock(message)
        return await self.unlock(lock_key, message.task_id)

    async def acquire_lock(self, message: "TaskiqMessage") -> bool:
        """Try to take the lock for *message*; return whether it succeeded."""
        lock_key = self.generate_lock(message)
        lock_expire = await self.get_lock_expire(message)
        return await self.lock(lock_key, message.task_id, lock_expire)

    @staticmethod
    def generate_lock_key(task_name: str, task_args: list, task_kwargs: dict, key_prefix: str) -> str:
        """Return ``key_prefix`` followed by an md5 digest of the task identity.

        Arguments are serialized with sorted keys, so keyword order does not
        matter. The ``str()`` of the orjson bytes (e.g. ``"b'[1]'"``) is
        what gets hashed. It is kept that way so existing lock keys stay
        stable.
        """
        str_args = str(orjson.dumps(task_args, option=orjson.OPT_SORT_KEYS))
        str_kwargs = str(orjson.dumps(task_kwargs, option=orjson.OPT_SORT_KEYS))
        task_hash = md5((task_name + str_args + str_kwargs).encode()).hexdigest()
        return key_prefix + task_hash

    def generate_lock(self, message: "TaskiqMessage") -> str:
        """Build the lock key that identifies *message* as a possible duplicate.

        Without a ``unique_on`` label, the raw positional and keyword
        arguments are hashed. With it, the arguments are bound to the task
        function's signature (defaults included), and only the listed
        parameters are hashed. Passing a value positionally or by keyword
        then yields the same key.
        """
        if unique_on := message.labels.get(self.UNIQUE_ON_LABEL):
            if isinstance(unique_on, str):
                unique_on = [unique_on]

            task = self.broker.find_task(message.task_name)
            bound = inspect.signature(task.original_func).bind(
                *message.args,
                **message.kwargs,
            )
            # ``bind`` records only the arguments actually passed. Fill in
            # defaults so a unique_on parameter the caller omitted does not
            # raise KeyError below.
            bound.apply_defaults()

            unique_args: list = []
            unique_kwargs = {key: bound.arguments[key] for key in unique_on}
        else:
            unique_args = message.args
            unique_kwargs = message.kwargs

        return self.generate_lock_key(
            task_name=str(message.task_name),
            task_args=unique_args,
            task_kwargs=unique_kwargs,
            key_prefix=self.KEY_PREFIX,
        )

Task example

@broker.task(singleton=True, unique_on=["id", "name"])
async def my_singleton_task(id: str, name: str) -> None:
    """Example singleton task, deduplicated on its ``id`` and ``name`` arguments."""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant