Skip to content

Commit

Permalink
Merge pull request #3824 from bjester/hotfixes
Browse files Browse the repository at this point in the history
Fix task duplication when kwargs contain UUIDs
  • Loading branch information
bjester committed Nov 17, 2022
2 parents 12e4827 + 3dbc390 commit 511865a
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
25 changes: 25 additions & 0 deletions contentcuration/contentcuration/tests/test_asynctask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

import threading
import uuid

from celery import states
from celery.result import allow_join_result
Expand Down Expand Up @@ -205,6 +206,30 @@ def test_fetch_or_enqueue_task(self):
async_result = test_task.fetch_or_enqueue(self.user, is_test=True)
self.assertEqual(expected_task.task_id, async_result.task_id)

def test_fetch_or_enqueue_task__channel_id(self):
    """
    Enqueues a task with a UUID `channel_id`, then verifies that fetch-or-enqueue with the same UUID
    hands back a result with the same task ID
    """
    task_channel = uuid.uuid4()
    enqueued = test_task.enqueue(self.user, channel_id=task_channel)
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=task_channel)
    self.assertEqual(enqueued.task_id, fetched.task_id)

def test_fetch_or_enqueue_task__channel_id__hex(self):
    """
    Enqueues a task with a hex string `channel_id`, then verifies that fetch-or-enqueue with the same
    hex string hands back a result with the same task ID
    """
    channel_hex = uuid.uuid4().hex
    enqueued = test_task.enqueue(self.user, channel_id=channel_hex)
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=channel_hex)
    self.assertEqual(enqueued.task_id, fetched.task_id)

def test_fetch_or_enqueue_task__channel_id__hex_then_uuid(self):
    """
    Enqueues a task with the hex string form of a `channel_id`, then verifies that fetch-or-enqueue
    with the equivalent UUID object hands back a result with the same task ID
    """
    task_channel = uuid.uuid4()
    enqueued = test_task.enqueue(self.user, channel_id=task_channel.hex)
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=task_channel)
    self.assertEqual(enqueued.task_id, fetched.task_id)

def test_fetch_or_enqueue_task__channel_id__uuid_then_hex(self):
    """
    Enqueues a task with a UUID object `channel_id`, then verifies that fetch-or-enqueue with the
    equivalent hex string hands back a result with the same task ID
    """
    task_channel = uuid.uuid4()
    enqueued = test_task.enqueue(self.user, channel_id=task_channel)
    fetched = test_task.fetch_or_enqueue(self.user, channel_id=task_channel.hex)
    self.assertEqual(enqueued.task_id, fetched.task_id)

def test_requeue_task(self):
existing_task_ids = requeue_test_task.find_ids()
self.assertEqual(len(existing_task_ids), 0)
Expand Down
26 changes: 18 additions & 8 deletions contentcuration/contentcuration/utils/celery/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ def fetch(self, task_id):
"""
return self.AsyncResult(task_id)

def fetch_match(self, task_id, **kwargs):
def _fetch_match(self, task_id, **kwargs):
"""
Gets the result object for a task, assuming it was called async, and ensures it was called with kwargs
Gets the result object for a task, assuming it was called async, and ensures it was called with kwargs and
assumes that kwargs has been decoded from a prepared form
:param task_id: The hex task ID
:param kwargs: The kwargs the task was called with, which must match when fetching
:return: A CeleryAsyncResult
Expand All @@ -160,6 +161,12 @@ def fetch_match(self, task_id, **kwargs):
return async_result
return None

def _prepare_kwargs(self, kwargs):
    """
    Encodes task kwargs with the backend, after replacing any top-level UUID values with their hex strings
    :param kwargs: A dict of the kwargs the task is called with
    :return: Whatever `self.backend.encode` returns for the normalized kwargs
    """
    normalized = {}
    for name, arg in kwargs.items():
        # a UUID object and its hex string must normalize to the same value before encoding
        normalized[name] = arg.hex if isinstance(arg, uuid.UUID) else arg
    return self.backend.encode(normalized)

def enqueue(self, user, **kwargs):
"""
Enqueues the task called with `kwargs`, and requires the user who wants to enqueue it. If `channel_id` is
Expand All @@ -176,14 +183,16 @@ def enqueue(self, user, **kwargs):
raise TypeError("All tasks must be assigned to a user.")

task_id = uuid.uuid4().hex
channel_id = kwargs.get("channel_id")
prepared_kwargs = self._prepare_kwargs(kwargs)
transcoded_kwargs = self.backend.decode(prepared_kwargs)
channel_id = transcoded_kwargs.get("channel_id")

logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {kwargs}")
logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {prepared_kwargs}")

# returns a CeleryAsyncResult
async_result = self.apply_async(
task_id=task_id,
kwargs=kwargs,
kwargs=transcoded_kwargs,
)

# ensure the result is saved to the backend (database)
Expand All @@ -192,7 +201,7 @@ def enqueue(self, user, **kwargs):
# after calling apply, we should have task result model, so get it and set our custom fields
task_result = get_task_model(self, task_id)
task_result.task_name = self.name
task_result.task_kwargs = self.backend.encode_content(kwargs)[2]
task_result.task_kwargs = prepared_kwargs
task_result.user = user
task_result.channel_id = channel_id
task_result.save()
Expand All @@ -211,9 +220,10 @@ def fetch_or_enqueue(self, user, **kwargs):
# if we're eagerly executing the task (synchronously), then we shouldn't check for an existing task because
# implementations probably aren't prepared to rely on an existing asynchronous task
if not self.app.conf.task_always_eager:
task_ids = self.find_incomplete_ids(**kwargs).order_by("date_created")[:1]
transcoded_kwargs = self.backend.decode(self._prepare_kwargs(kwargs))
task_ids = self.find_incomplete_ids(**transcoded_kwargs).order_by("date_created")[:1]
if task_ids:
async_result = self.fetch_match(task_ids[0], **kwargs)
async_result = self._fetch_match(task_ids[0], **transcoded_kwargs)
if async_result:
logging.info(f"Fetched matching task {self.name} for user {user.pk} with id {async_result.id} | {kwargs}")
return async_result
Expand Down

0 comments on commit 511865a

Please sign in to comment.