diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index c1bf588c2bbd4..c9b60e303f734 100644
--- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -110,6 +110,24 @@ class TIRescheduleStatePayload(BaseModel):
     end_date: UtcDateTime
 
 
+class TIRetryStatePayload(BaseModel):
+    """Schema for updating TaskInstance to an up_for_retry state."""
+
+    state: Annotated[
+        Literal[IntermediateTIState.UP_FOR_RETRY],
+        # Specify a default in the schema, but not in code, so Pydantic marks it as required.
+        WithJsonSchema(
+            {
+                "type": "string",
+                "enum": [IntermediateTIState.UP_FOR_RETRY],
+                "default": IntermediateTIState.UP_FOR_RETRY,
+            }
+        ),
+    ]
+    end_date: UtcDateTime
+    task_retries: int
+
+
 def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
     """
     Determine the discriminator key for TaskInstance state transitions.
@@ -129,6 +147,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
         return "deferred"
     elif state == TIState.UP_FOR_RESCHEDULE:
         return "up_for_reschedule"
+    elif state == TIState.UP_FOR_RETRY:
+        return "up_for_retry"
     return "_other_"
 
 
@@ -140,6 +160,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str:
         Annotated[TITargetStatePayload, Tag("_other_")],
         Annotated[TIDeferredStatePayload, Tag("deferred")],
         Annotated[TIRescheduleStatePayload, Tag("up_for_reschedule")],
+        Annotated[TIRetryStatePayload, Tag("up_for_retry")],
     ],
     Discriminator(ti_state_discriminator),
 ]
diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py
index 2184c2946b8ef..24f199ae7a308 100644
--- a/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -35,6 +35,7 @@
     TIEnterRunningPayload,
     TIHeartbeatInfo,
     TIRescheduleStatePayload,
+    TIRetryStatePayload,
     TIRunContext,
     TIStateUpdate,
     TITerminalStatePayload,
@@ -167,6 +168,7 @@ def ti_run(
         status.HTTP_404_NOT_FOUND: {"description": "Task Instance not found"},
         status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"},
         status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"},
+        status.HTTP_400_BAD_REQUEST: {"description": "Not a valid state transition"},
     },
 )
 def ti_update_state(
@@ -252,6 +254,20 @@ def ti_update_state(
         query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
         # clear the next_method and next_kwargs so that none of the retries pick them up
         query = query.values(state=State.UP_FOR_RESCHEDULE, next_method=None, next_kwargs=None)
+    elif isinstance(ti_patch_payload, TIRetryStatePayload):
+        task_instance = session.get(TI, ti_id_str)
+        if not _is_eligible_to_retry(task_instance, ti_patch_payload.task_retries):
+            log.error("Task Instance %s cannot be retried", ti_id_str)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "reason": "bad_request",
+                    "message": "Task Instance is not eligible to retry",
+                },
+            )
+        query = update(TI).where(TI.id == ti_id_str)
+        query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
+        query = query.values(state=State.UP_FOR_RETRY, next_method=None, next_kwargs=None)
     # TODO: Replace this with FastAPI's Custom Exception handling:
     # https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
     try:
@@ -354,3 +370,20 @@ def ti_put_rtif(
     _update_rtif(task_instance, put_rtif_payload, session)
 
     return {"message": "Rendered task instance fields successfully set"}
+
+
+def _is_eligible_to_retry(task_instance: TI, task_retries: int) -> bool:
+    """
+    Return whether the task instance is eligible for retry.
+
+    :param task_instance: the task instance
+    :param task_retries: the retry count sent in the retry state payload
+
+    :meta private:
+    """
+    if task_instance.state == State.RESTARTING:
+        # A task in RESTARTING state is always eligible for retry
+        return True
+
+    # Coerce to bool so that task_retries == 0 yields False instead of leaking the int 0
+    return bool(task_retries) and task_instance.try_number <= task_instance.max_tries
diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py
index 787dcf55ab818..e73f265a91cbc 100644
--- a/task_sdk/src/airflow/sdk/api/client.py
+++ b/task_sdk/src/airflow/sdk/api/client.py
@@ -36,6 +36,7 @@
     TIEnterRunningPayload,
     TIHeartbeatInfo,
     TIRescheduleStatePayload,
+    TIRetryStatePayload,
     TIRunContext,
     TITerminalStatePayload,
     ValidationError as RemoteValidationError,
@@ -49,7 +50,7 @@
 if TYPE_CHECKING:
     from datetime import datetime
 
-    from airflow.sdk.execution_time.comms import RescheduleTask
+    from airflow.sdk.execution_time.comms import RescheduleTask, RetryTask
     from airflow.typing_compat import ParamSpec
 
     P = ParamSpec("P")
@@ -146,6 +147,13 @@ def reschedule(self, id: uuid.UUID, msg: RescheduleTask):
         # Create a reschedule state payload from msg
         self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
 
+    def retry(self, id: uuid.UUID, msg: RetryTask):
+        """Tell the API server that this TI wants to retry."""
+        body = TIRetryStatePayload(**msg.model_dump(exclude_unset=True))
+
+        # Create a retry state payload from msg
+        self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())
+
     def set_rtif(self, id: uuid.UUID, body: dict[str, str]) -> dict[str, bool]:
         """Set Rendered Task Instance Fields via the API server."""
         self.client.put(f"task-instances/{id}/rtif", json=body)
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index 00187364c8669..f37a981ec1cf9 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -112,6 +112,16 @@ class TIRescheduleStatePayload(BaseModel):
     reschedule_date: Annotated[datetime, Field(title="Reschedule Date")]
 
 
+class TIRetryStatePayload(BaseModel):
+    """
+    Schema for updating TaskInstance to an up_for_retry state.
+    """
+
+    state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = "up_for_retry"
+    end_date: Annotated[datetime, Field(title="End Date")]
+    task_retries: Annotated[int, Field(title="Task Retries")]
+
+
 class TITargetStatePayload(BaseModel):
     """
     Schema for updating TaskInstance to a target state, excluding terminal and running states.
diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py
index e1a3ce034611f..8ba7227718aca 100644
--- a/task_sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -55,6 +55,7 @@
     TerminalTIState,
     TIDeferredStatePayload,
     TIRescheduleStatePayload,
+    TIRetryStatePayload,
     TIRunContext,
     VariableResponse,
     XComResponse,
@@ -122,6 +123,12 @@ class RescheduleTask(TIRescheduleStatePayload):
     type: Literal["RescheduleTask"] = "RescheduleTask"
 
 
+class RetryTask(TIRetryStatePayload):
+    """Update a task instance state to up_for_retry."""
+
+    type: Literal["RetryTask"] = "RetryTask"
+
+
 class GetXCom(BaseModel):
     key: str
     dag_id: str
@@ -200,6 +207,7 @@ class SetRenderedFields(BaseModel):
         SetXCom,
         SetRenderedFields,
         RescheduleTask,
+        RetryTask,
     ],
     Field(discriminator="type"),
 ]
diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index bf2aa0778c3b2..056001ca58ab7 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -55,6 +55,7 @@
     GetXCom,
     PutVariable,
     RescheduleTask,
+    RetryTask,
     SetXCom,
     StartupDetails,
     TaskState,
@@ -702,6 +703,9 @@ def _handle_request(self, msg, log):
         elif isinstance(msg, RescheduleTask):
             self._terminal_state = IntermediateTIState.UP_FOR_RESCHEDULE
             self.client.task_instances.reschedule(self.id, msg)
+        elif isinstance(msg, RetryTask):
+            self._terminal_state = IntermediateTIState.UP_FOR_RETRY
+            self.client.task_instances.retry(self.id, msg)
         elif isinstance(msg, SetXCom):
             self.client.xcoms.set(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.value, msg.map_index)
         elif isinstance(msg, PutVariable):
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index ba4ed881039e7..a1a39c66acfbb 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -34,6 +34,7 @@
 from airflow.sdk.execution_time.comms import (
     DeferTask,
     RescheduleTask,
+    RetryTask,
     SetRenderedFields,
     StartupDetails,
     TaskState,
@@ -296,8 +297,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         )
 
         # TODO: Run task failure callbacks here
-    except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated):
+    except AirflowTaskTerminated:
         ...
+    except (AirflowTaskTimeout, AirflowException):
+        # If the task could not be loaded we do not know its retry count, so default it to 0.
+        # NOTE(review): `or 0` also covers a falsy/None `retries`; confirm whether that can happen.
+        task = getattr(ti, "task", None)
+        msg = RetryTask(
+            end_date=datetime.now(tz=timezone.utc),
+            task_retries=(task.retries or 0) if task else 0,
+        )
     except SystemExit:
         ...
     except BaseException:
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index be1d945b82e4c..131cf360e7aa9 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -46,6 +46,7 @@
     GetXCom,
     PutVariable,
     RescheduleTask,
+    RetryTask,
     SetXCom,
     TaskState,
     VariableResult,
@@ -794,6 +795,14 @@ def watched_subprocess(self, mocker):
                 "",
                 id="patch_task_instance_to_deferred",
             ),
+            pytest.param(
+                RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z"), task_retries=1),
+                b"",
+                "task_instances.retry",
+                (TI_ID, RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z"), task_retries=1)),
+                "",
+                id="patch_task_instance_to_retry",
+            ),
             pytest.param(
                 RescheduleTask(
                     reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py
index 4ed5f8f1598f3..b5609452a42a8 100644
--- a/tests/api_fastapi/execution_api/routes/test_task_instances.py
+++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py
@@ -340,6 +340,43 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
         assert trs[0].map_index == -1
         assert trs[0].duration == 129600
 
+    def test_ti_update_state_to_retry(self, client, session, create_task_instance, time_machine):
+        """
+        Test that the transition to the up_for_retry state is handled correctly.
+        """
+
+        instant = timezone.datetime(2024, 10, 30)
+        time_machine.move_to(instant, tick=False)
+
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_retry",
+            state=State.RUNNING,
+            session=session,
+        )
+        ti.start_date = instant
+        session.commit()
+
+        payload = {
+            "state": "up_for_retry",
+            "end_date": DEFAULT_END_DATE.isoformat(),
+            # a running task moving to up_for_retry
+            "task_retries": 1,
+        }
+
+        response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload)
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        tis = session.query(TaskInstance).all()
+        assert len(tis) == 1
+        assert tis[0].state == TaskInstanceState.UP_FOR_RETRY
+        assert tis[0].next_method is None
+        assert tis[0].next_kwargs is None
+        assert tis[0].duration == 129600
+
 
 class TestTIHealthEndpoint:
     def setup_method(self):