diff --git a/warehouse/metrics_tools/compute/service.py b/warehouse/metrics_tools/compute/service.py index d47bdcf1..f3871911 100644 --- a/warehouse/metrics_tools/compute/service.py +++ b/warehouse/metrics_tools/compute/service.py @@ -38,22 +38,46 @@ logger.setLevel(logging.DEBUG) -class JobFailed(Exception): +class JobError(Exception): pass -class JobTasksFailed(JobFailed): +class JobFailed(JobError): exceptions: t.List[Exception] + cancellations: t.List[str] failures: int - def __init__(self, job_id: str, failures: int, exceptions: t.List[Exception]): + def __init__( + self, + job_id: str, + failures: int, + exceptions: t.List[Exception], + cancellations: t.List[str], + ): self.failures = failures self.exceptions = exceptions + self.cancellations = cancellations super().__init__( - f"job[{job_id}] failed with {failures} failures and {len(exceptions)} exceptions" + f"job[{job_id}] failed with {failures} failures and {len(exceptions)} exceptions and {len(cancellations)} cancellations" ) +class JobTaskCancelled(JobError): + task_id: str + + def __init__(self, task_id: str): + self.task_id = task_id + super().__init__(f"task {task_id} was cancelled") + + +class JobTaskFailed(JobError): + exception: Exception + + def __init__(self, exception: Exception): + self.exception = exception + super().__init__(f"task failed with exception: {exception}") + + class MetricsCalculationService: id: str gcs_bucket: str @@ -153,10 +177,15 @@ async def _handle_query_job_submit_request( self.logger.warning("job[{job_id}] batch count mismatch") exceptions = [] + cancellations = [] for next_task in asyncio.as_completed(tasks): try: await next_task + except JobTaskCancelled as e: + cancellations.append(e.task_id) + except JobTaskFailed as e: + exceptions.append(e.exception) except Exception as e: self.logger.error( f"job[{job_id}] task failed with uncaught exception: {e}" @@ -168,8 +197,8 @@ async def _handle_query_job_submit_request( # If there are any exceptions then we report those as failed and short # circuit this method - if len(exceptions) > 0: - raise JobTasksFailed(job_id, len(exceptions), exceptions) + if len(exceptions) > 0 or len(cancellations) > 0: + raise JobFailed(job_id, len(exceptions), exceptions, cancellations) # Import the final result into the database self.logger.info("job[{job_id}]: importing final result into the database") @@ -246,9 +275,11 @@ async def _submit_query_task_to_scheduler( except CancelledError as e: self.logger.error(f"job[{job_id}] task cancelled {e.args}") await self._notify_job_task_cancelled(job_id, task_id) + raise JobTaskCancelled(task_id) except Exception as e: self.logger.error(f"job[{job_id}] task failed with exception: {e}") await self._notify_job_task_failed(job_id, task_id, e) + raise JobTaskFailed(e) return task_id async def close(self):