diff --git a/contentcuration/contentcuration/tests/test_asynctask.py b/contentcuration/contentcuration/tests/test_asynctask.py index 58017559ee..85b40b6237 100644 --- a/contentcuration/contentcuration/tests/test_asynctask.py +++ b/contentcuration/contentcuration/tests/test_asynctask.py @@ -1,6 +1,7 @@ from __future__ import absolute_import import threading +import uuid from celery import states from celery.result import allow_join_result @@ -205,6 +206,30 @@ def test_fetch_or_enqueue_task(self): async_result = test_task.fetch_or_enqueue(self.user, is_test=True) self.assertEqual(expected_task.task_id, async_result.task_id) + def test_fetch_or_enqueue_task__channel_id(self): + channel_id = uuid.uuid4() + expected_task = test_task.enqueue(self.user, channel_id=channel_id) + async_result = test_task.fetch_or_enqueue(self.user, channel_id=channel_id) + self.assertEqual(expected_task.task_id, async_result.task_id) + + def test_fetch_or_enqueue_task__channel_id__hex(self): + channel_id = uuid.uuid4() + expected_task = test_task.enqueue(self.user, channel_id=channel_id.hex) + async_result = test_task.fetch_or_enqueue(self.user, channel_id=channel_id.hex) + self.assertEqual(expected_task.task_id, async_result.task_id) + + def test_fetch_or_enqueue_task__channel_id__hex_then_uuid(self): + channel_id = uuid.uuid4() + expected_task = test_task.enqueue(self.user, channel_id=channel_id.hex) + async_result = test_task.fetch_or_enqueue(self.user, channel_id=channel_id) + self.assertEqual(expected_task.task_id, async_result.task_id) + + def test_fetch_or_enqueue_task__channel_id__uuid_then_hex(self): + channel_id = uuid.uuid4() + expected_task = test_task.enqueue(self.user, channel_id=channel_id) + async_result = test_task.fetch_or_enqueue(self.user, channel_id=channel_id.hex) + self.assertEqual(expected_task.task_id, async_result.task_id) + def test_requeue_task(self): existing_task_ids = requeue_test_task.find_ids() self.assertEqual(len(existing_task_ids), 0) diff --git a/contentcuration/contentcuration/utils/celery/tasks.py b/contentcuration/contentcuration/utils/celery/tasks.py index c006a79c98..0d066d3329 100644 --- a/contentcuration/contentcuration/utils/celery/tasks.py +++ b/contentcuration/contentcuration/utils/celery/tasks.py @@ -146,9 +146,10 @@ def fetch(self, task_id): """ return self.AsyncResult(task_id) - def fetch_match(self, task_id, **kwargs): + def _fetch_match(self, task_id, **kwargs): """ - Gets the result object for a task, assuming it was called async, and ensures it was called with kwargs + Gets the result object for a task, assuming it was called async, and ensures it was called with kwargs and + assumes that kwargs is has been decoded from an prepared form :param task_id: The hex task ID :param kwargs: The kwargs the task was called with, which must match when fetching :return: A CeleryAsyncResult @@ -160,6 +161,12 @@ def fetch_match(self, task_id, **kwargs): return async_result return None + def _prepare_kwargs(self, kwargs): + return self.backend.encode({ + key: value.hex if isinstance(value, uuid.UUID) else value + for key, value in kwargs.items() + }) + def enqueue(self, user, **kwargs): """ Enqueues the task called with `kwargs`, and requires the user who wants to enqueue it. If `channel_id` is @@ -176,14 +183,16 @@ def enqueue(self, user, **kwargs): raise TypeError("All tasks must be assigned to a user.") task_id = uuid.uuid4().hex - channel_id = kwargs.get("channel_id") + prepared_kwargs = self._prepare_kwargs(kwargs) + transcoded_kwargs = self.backend.decode(prepared_kwargs) + channel_id = transcoded_kwargs.get("channel_id") - logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {kwargs}") + logging.info(f"Enqueuing task:id {self.name}:{task_id} for user:channel {user.pk}:{channel_id} | {prepared_kwargs}") # returns a CeleryAsyncResult async_result = self.apply_async( task_id=task_id, - kwargs=kwargs, + kwargs=transcoded_kwargs, ) # ensure the result is saved to the backend (database) @@ -192,7 +201,7 @@ def enqueue(self, user, **kwargs): # after calling apply, we should have task result model, so get it and set our custom fields task_result = get_task_model(self, task_id) task_result.task_name = self.name - task_result.task_kwargs = self.backend.encode_content(kwargs)[2] + task_result.task_kwargs = prepared_kwargs task_result.user = user task_result.channel_id = channel_id task_result.save() @@ -211,9 +220,10 @@ def fetch_or_enqueue(self, user, **kwargs): # if we're eagerly executing the task (synchronously), then we shouldn't check for an existing task because # implementations probably aren't prepared to rely on an existing asynchronous task if not self.app.conf.task_always_eager: - task_ids = self.find_incomplete_ids(**kwargs).order_by("date_created")[:1] + transcoded_kwargs = self.backend.decode(self._prepare_kwargs(kwargs)) + task_ids = self.find_incomplete_ids(**transcoded_kwargs).order_by("date_created")[:1] if task_ids: - async_result = self.fetch_match(task_ids[0], **kwargs) + async_result = self._fetch_match(task_ids[0], **transcoded_kwargs) if async_result: logging.info(f"Fetched matching task {self.name} for user {user.pk} with id {async_result.id} | {kwargs}") return async_result