Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling task retries in task SDK + execution API #45106

Merged
merged 11 commits into from
Dec 30, 2024
3 changes: 3 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class TITerminalStatePayload(BaseModel):
end_date: UtcDateTime
"""When the task completed executing"""

"""Indicates if the task should retry before failing or not."""
should_retry: bool = False


class TITargetStatePayload(BaseModel):
"""Schema for updating TaskInstance to a target state, excluding terminal and running states."""
Expand Down
30 changes: 28 additions & 2 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,22 @@ def ti_update_state(
)

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude_unset=True)
# We do not need to deserialize "should_retry" -- it is used for dynamic decision-making within failed state
data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"should_retry"})

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
updated_state = ti_patch_payload.state
if ti_patch_payload.state == State.FAILED:
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
updated_state = State.FAILED
task_instance = session.get(TI, ti_id_str)
if _is_eligible_to_retry(task_instance, ti_patch_payload.should_retry):
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
query = query.values(state=State.UP_FOR_RETRY)
updated_state = State.UP_FOR_RETRY
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down Expand Up @@ -359,3 +364,24 @@ def ti_put_rtif(
_update_rtif(task_instance, put_rtif_payload, session)

return {"message": "Rendered task instance fields successfully set"}


def _is_eligible_to_retry(task_instance: TI, should_retry: bool) -> bool:
"""
Is task instance is eligible for retry.

:param task_instance: the task instance

:meta private:
"""
if not should_retry:
return False

if task_instance.state == State.RESTARTING:
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True

# max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for
# retries from the task SDK now, we can handle using max_tries
return task_instance.max_tries and task_instance.try_number <= task_instance.max_tries
4 changes: 4 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
# TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def fail(self, id: uuid.UUID, when: datetime, should_retry: bool):
    """
    Tell the API server that this TI has to fail, with or without retries.

    :param id: id of the task instance to fail
    :param when: timestamp recorded as the task's end date
    :param should_retry: whether the server should consider the task for a retry
        before finally marking it as failed
    """
    body = TITerminalStatePayload(end_date=when, state=TerminalTIState.FAILED, should_retry=should_retry)
    self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
Expand Down
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,4 @@ class TITerminalStatePayload(BaseModel):

state: TerminalTIState
end_date: Annotated[datetime, Field(title="End Date")]
should_retry: Annotated[bool | None, Field(title="Should Retry")] = False
17 changes: 16 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,25 @@ class TaskState(BaseModel):
- anything else = FAILED
"""

state: TerminalTIState
state: Literal[TerminalTIState.SUCCESS, TerminalTIState.REMOVED, TerminalTIState.SKIPPED]
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"


class FailState(BaseModel):
    """
    Update a task's state to FAILED.

    Contains attributes specific to failing a task, such as the
    ability to retry.
    """

    # Defaults to True: a task failing at runtime is retryable unless the sender
    # explicitly opts out (should_retry=False is used for parse/system errors).
    should_retry: bool = True
    end_date: datetime | None = None
    # Pinned to FAILED; this message only ever conveys the failed terminal state.
    state: Literal[TerminalTIState.FAILED] = TerminalTIState.FAILED
    # Discriminator used when decoding ToSupervisor messages.
    type: Literal["FailState"] = "FailState"


class DeferTask(TIDeferredStatePayload):
"""Update a task instance state to deferred."""

Expand Down Expand Up @@ -232,6 +246,7 @@ class SetRenderedFields(BaseModel):
ToSupervisor = Annotated[
Union[
TaskState,
FailState,
GetXCom,
GetConnection,
GetVariable,
Expand Down
15 changes: 13 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
FailState,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -293,6 +294,9 @@ class WatchedSubprocess:
_exit_code: int | None = attrs.field(default=None, init=False)
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)
# denotes if a request to `fail` has been sent from the _handle_requests or not, or it will be handled in wait()
# useful to synchronise the API requests for `fail` between handle_requests and wait
_fail_request_sent: bool = attrs.field(default=False, init=False)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
Expand Down Expand Up @@ -521,7 +525,7 @@ def wait(self) -> int:
# to reflect the final state of the process.
# For states like `deferred`, the process will exit with 0, but the state will be updated
# by the subprocess in the `handle_requests` method.
if self.final_state in TerminalTIState:
if (not self._fail_request_sent) and self.final_state in TerminalTIState:
self.client.task_instances.finish(
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
)
Expand Down Expand Up @@ -710,7 +714,14 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
log.debug("Received message from task runner", msg=msg)
resp = None
if isinstance(msg, TaskState):
if isinstance(msg, FailState):
self._terminal_state = TerminalTIState.FAILED
self._task_end_time_monotonic = time.monotonic()
self._fail_request_sent = True
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
log.debug("IN SIDE FAILSTATE.")
self.client.task_instances.fail(self.id, datetime.now(tz=timezone.utc), msg.should_retry)
elif isinstance(msg, TaskState):
log.debug("IN SIDE TaskState.")
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, GetConnection):
Expand Down
11 changes: 6 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
FailState,
GetXCom,
RescheduleTask,
SetRenderedFields,
Expand Down Expand Up @@ -408,9 +409,10 @@ def run(ti: RuntimeTaskInstance, log: Logger):

# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
msg = FailState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
should_retry=False,
)

# TODO: Run task failure callbacks here
Expand All @@ -421,17 +423,16 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# External state updates are already handled with `ti_heartbeat` and will be
# updated already be another UI API. So, these exceptions should ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
msg = TaskState(
msg = FailState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
should_retry=False,
)
# TODO: Run task failure callbacks here
except SystemExit:
...
except BaseException:
# TODO: Handle TI handle failure
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
raise

msg = FailState(should_retry=True, end_date=datetime.now(tz=timezone.utc))
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)

Expand Down
20 changes: 19 additions & 1 deletion task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
resp = client.task_instances.start(ti_id, 100, start_date)
assert resp == ti_context

@pytest.mark.parametrize("state", [state for state in TerminalTIState])
@pytest.mark.parametrize("state", [state for state in TerminalTIState if state != TerminalTIState.FAILED])
def test_task_instance_finish(self, state):
# Simulate a successful response from the server that finishes (moved to terminal state) a task
ti_id = uuid6.uuid7()
Expand All @@ -139,6 +139,24 @@ def handle_request(request: httpx.Request) -> httpx.Response:
client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z")

def test_task_instance_fail(self):
    # Simulate a successful response from the server that fails a task with retry.
    ti_id = uuid6.uuid7()
    expected_path = f"/task-instances/{ti_id}/state"

    def handle_request(request: httpx.Request) -> httpx.Response:
        # Anything other than the state-patch endpoint is a bad request.
        if request.url.path != expected_path:
            return httpx.Response(status_code=400, json={"detail": "Bad Request"})
        payload = json.loads(request.read())
        assert payload["end_date"] == "2024-10-31T12:00:00Z"
        assert payload["state"] == TerminalTIState.FAILED
        assert payload["should_retry"] is True
        return httpx.Response(status_code=204)

    client = make_client(transport=httpx.MockTransport(handle_request))
    client.task_instances.fail(ti_id, when="2024-10-31T12:00:00Z", should_retry=True)

def test_task_instance_heartbeat(self):
# Simulate a successful response from the server that sends a heartbeat for a ti
ti_id = uuid6.uuid7()
Expand Down
18 changes: 18 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
FailState,
GetConnection,
GetVariable,
GetXCom,
Expand Down Expand Up @@ -884,6 +885,19 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
# testing to see if supervisor can handle FailState message
pytest.param(
FailState(
state=TerminalTIState.FAILED,
end_date=timezone.parse("2024-10-31T12:00:00Z"),
should_retry=False,
),
b"",
"task_instances.fail",
(TI_ID, timezone.parse("2024-11-7T12:00:00Z"), False),
"",
id="patch_task_instance_to_failed",
),
pytest.param(
SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
b"",
Expand All @@ -903,6 +917,7 @@ def test_handle_requests(
client_attr_path,
method_arg,
mock_response,
time_machine,
):
"""
Test handling of different messages to the subprocess. For any new message type, add a
Expand All @@ -916,6 +931,9 @@ def test_handle_requests(
4. Verifies that the response is correctly decoded.
"""

instant = tz.datetime(2024, 11, 7, 12, 0, 0, 0)
time_machine.move_to(instant, tick=False)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved

# Mock the client method. E.g. `client.variables.get` or `client.connections.get`
mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client)
mock_client_method.return_value = mock_response
Expand Down
92 changes: 90 additions & 2 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
FailState,
GetConnection,
SetRenderedFields,
StartupDetails,
Expand Down Expand Up @@ -256,6 +257,91 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
)


@pytest.mark.parametrize(
    "retries",
    [None, 0, 3],
)
def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, retries):
    """Test running a basic task that raises a base exception."""
    from airflow.providers.standard.operators.python import PythonOperator

    # Callable always raises ZeroDivisionError, driving run() down the failure path.
    task = PythonOperator(
        task_id="zero_division_error",
        python_callable=lambda: 1 / 0,
    )
    if retries is not None:
        task.retries = retries

    what = StartupDetails(
        ti=TaskInstance(
            id=uuid7(),
            task_id="zero_division_error",
            dag_id="basic_dag_base_exception",
            run_id="c",
            try_number=1,
        ),
        file="",
        requests_fd=0,
        ti_context=make_ti_context(),
    )

    ti = mocked_parse(what, "basic_dag_base_exception", task)

    # Freeze time so the end_date in the emitted message is deterministic.
    instant = timezone.datetime(2024, 12, 3, 10, 0)
    time_machine.move_to(instant, tick=False)

    with mock.patch(
        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
    ) as mock_supervisor_comms:
        run(ti, log=mock.MagicMock())

    # Regardless of the task-level retries value, the runner reports FailState with
    # should_retry=True; retry eligibility is decided server-side via max_tries.
    mock_supervisor_comms.send_request.assert_called_once_with(
        msg=FailState(
            should_retry=True,
            state=TerminalTIState.FAILED,
            end_date=instant,
        ),
        log=mock.ANY,
    )


def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
    """Test running a DAG with templated task."""
    from airflow.providers.standard.operators.bash import BashOperator

    task = BashOperator(
        task_id="templated_task",
        bash_command="echo 'Logical date is {{ logical_date }}'",
    )

    what = StartupDetails(
        ti=TaskInstance(
            id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
        ),
        file="",
        requests_fd=0,
        ti_context=make_ti_context(),
    )
    mocked_parse(what, "basic_templated_dag", task)

    with mock.patch(
        "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
    ) as mock_supervisor_comms:
        mock_supervisor_comms.get_message.return_value = what
        startup()

    # Startup forwards the raw, unrendered template string as a rendered field;
    # actual template rendering happens later at task execution time.
    mock_supervisor_comms.send_request.assert_called_once_with(
        msg=SetRenderedFields(
            rendered_fields={
                "bash_command": "echo 'Logical date is {{ logical_date }}'",
                "cwd": None,
                "env": None,
            }
        ),
        log=mock.ANY,
    )


@pytest.mark.parametrize(
["task_params", "expected_rendered_fields"],
[
Expand Down Expand Up @@ -349,7 +435,9 @@ def execute(self, context):
),
],
)
def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
def test_run_basic_failed_without_retries(
time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context
):
"""Test running a basic task that marks itself as failed by raising exception."""

class CustomOperator(BaseOperator):
Expand Down Expand Up @@ -381,7 +469,7 @@ def execute(self, context):
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
msg=FailState(state=TerminalTIState.FAILED, end_date=instant, should_retry=False), log=mock.ANY
)


Expand Down
Loading