Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling task retries in task SDK + execution API #45106

Merged
merged 11 commits into from
Dec 30, 2024
2 changes: 2 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class TITerminalStatePayload(BaseModel):
end_date: UtcDateTime
"""When the task completed executing"""

task_retries: int | None = None
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved


class TITargetStatePayload(BaseModel):
"""Schema for updating TaskInstance to a target state, excluding terminal and running states."""
Expand Down
29 changes: 27 additions & 2 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,22 @@ def ti_update_state(
)

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude_unset=True)
# We do not need to deserialize "task_retries" -- it is used for dynamic decision making within failed state
data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"task_retries"})

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
updated_state = ti_patch_payload.state
if ti_patch_payload.state == State.FAILED:
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
updated_state = State.FAILED
task_instance = session.get(TI, ti_id_str)
if _is_eligible_to_retry(task_instance, ti_patch_payload.task_retries):
query = query.values(state=State.UP_FOR_RETRY)
updated_state = State.UP_FOR_RETRY
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down Expand Up @@ -359,3 +364,23 @@ def ti_put_rtif(
_update_rtif(task_instance, put_rtif_payload, session)

return {"message": "Rendered task instance fields successfully set"}


def _is_eligible_to_retry(task_instance: TI, task_retries: int | None):
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
"""
Is task instance is eligible for retry.

:param task_instance: the task instance

:meta private:
"""
if task_instance.state == State.RESTARTING:
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True

if task_retries == -1:
# task_runner indicated that it doesn't know number of retries, guess it from the table
return task_instance.try_number <= task_instance.max_tries

return task_retries and task_instance.try_number <= task_instance.max_tries
5 changes: 4 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,14 @@ def start(self, id: uuid.UUID, pid: int, when: datetime) -> TIRunContext:
resp = self.client.patch(f"task-instances/{id}/run", content=body.model_dump_json())
return TIRunContext.model_validate_json(resp.read())

def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime, task_retries: int | None):
"""Tell the API server that this TI has reached a terminal state."""
# TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))

if task_retries:
body.task_retries = task_retries

self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
Expand Down
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,4 @@ class TITerminalStatePayload(BaseModel):

state: TerminalTIState
end_date: Annotated[datetime, Field(title="End Date")]
task_retries: Annotated[int | None, Field(title="Task Retries")] = None
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class TaskState(BaseModel):

state: TerminalTIState
end_date: datetime | None = None
task_retries: int | None = None
type: Literal["TaskState"] = "TaskState"


Expand Down
16 changes: 14 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ class WatchedSubprocess:
_exit_code: int | None = attrs.field(default=None, init=False)
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)
# denotes if a task `has` retries defined or not, helpful to send signals between the handle_requests and wait
_should_retry: bool = attrs.field(default=False, init=False)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
Expand Down Expand Up @@ -518,13 +520,15 @@ def wait(self) -> int:
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1

print("The exit code is", self._exit_code)

amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
# If the process has finished in a terminal state, update the state of the TaskInstance
# to reflect the final state of the process.
# For states like `deferred`, the process will exit with 0, but the state will be updated
# by the subprocess in the `handle_requests` method.
if self.final_state in TerminalTIState:
if self.final_state in TerminalTIState and not self._should_retry:
self.client.task_instances.finish(
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc), task_retries=None
)
return self._exit_code

Expand Down Expand Up @@ -714,6 +718,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
if msg.task_retries:
self.client.task_instances.finish(
id=self.id,
state=self.final_state,
when=datetime.now(tz=timezone.utc),
task_retries=msg.task_retries,
)
self._should_retry = True
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
Expand Down
9 changes: 7 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,13 @@ def run(ti: RuntimeTaskInstance, log: Logger):
except SystemExit:
...
except BaseException:
# TODO: Handle TI handle failure
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
raise
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if not getattr(ti, "task", None):
# We do not know about retries, let's mark it -1, so that the execution api can make a guess
msg.task_retries = -1
else:
# `None` indicates no retries provided, the default is anyway 0 which evaluates to false
msg.task_retries = ti.task.retries or None

if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
Expand Down
22 changes: 21 additions & 1 deletion task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,27 @@ def handle_request(request: httpx.Request) -> httpx.Response:
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z")
client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z", task_retries=None)

def test_task_instance_finish_with_retries(self):
# Simulate a successful response from the server that finishes (moved to terminal state) a task when retries are present
ti_id = uuid6.uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/state":
actual_body = json.loads(request.read())
assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
assert actual_body["state"] == TerminalTIState.FAILED
assert actual_body["task_retries"] == 2
return httpx.Response(
status_code=204,
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.finish(
ti_id, state=TerminalTIState.FAILED, when="2024-10-31T12:00:00Z", task_retries=2
)

def test_task_instance_heartbeat(self):
# Simulate a successful response from the server that sends a heartbeat for a ti
Expand Down
13 changes: 13 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,19 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
# checking if we are capable of handling if task_retries is passed
pytest.param(
TaskState(
state=TerminalTIState.FAILED,
end_date=timezone.parse("2024-10-31T12:00:00Z"),
task_retries=2,
),
b"",
"",
(),
"",
id="patch_task_instance_to_failed_with_retries",
),
pytest.param(
SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
b"",
Expand Down
118 changes: 118 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,124 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
)


@pytest.mark.parametrize(
["retries", "expected_msg"],
[
# No retries configured
pytest.param(None, TaskState(state=TerminalTIState.FAILED, task_retries=None)),
# Retries configured
pytest.param(2, TaskState(state=TerminalTIState.FAILED, task_retries=2)),
# Retries configured but with 0
pytest.param(0, TaskState(state=TerminalTIState.FAILED, task_retries=None)),
],
)
def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, retries, expected_msg):
"""Test running a basic task that raises a base exception."""
from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id="zero_division_error",
retries=retries,
python_callable=lambda: 1 / 0,
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
task_id="zero_division_error",
dag_id="basic_dag_base_exception",
run_id="c",
try_number=1,
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

ti = mocked_parse(what, "basic_dag_base_exception", task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
run(ti, log=mock.MagicMock())
expected_msg.end_date = instant
mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_msg, log=mock.ANY)


def test_run_raises_missing_task(time_machine, mocked_parse, make_ti_context):
"""Test running a basic dag with missing ti.task."""
from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id="missing_task",
python_callable=lambda: 1 / 0,
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(), task_id="missing_task", dag_id="basic_dag_missing_task", run_id="c", try_number=1
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

ti = mocked_parse(what, "basic_dag_missing_task", task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
# set ti.task as None
ti.task = None
run(ti, log=mock.MagicMock())
mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.FAILED, task_retries=-1, end_date=instant), log=mock.ANY
)


def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
"""Test running a DAG with templated task."""
from airflow.providers.standard.operators.bash import BashOperator

task = BashOperator(
task_id="templated_task",
bash_command="echo 'Logical date is {{ logical_date }}'",
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)
mocked_parse(what, "basic_templated_dag", task)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
mock_supervisor_comms.get_message.return_value = what
startup()

mock_supervisor_comms.send_request.assert_called_once_with(
msg=SetRenderedFields(
rendered_fields={
"bash_command": "echo 'Logical date is {{ logical_date }}'",
"cwd": None,
"env": None,
}
),
log=mock.ANY,
)


@pytest.mark.parametrize(
["task_params", "expected_rendered_fields"],
[
Expand Down
65 changes: 65 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,71 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
assert trs[0].map_index == -1
assert trs[0].duration == 129600

@pytest.mark.parametrize(
("retries", "expected_state"),
[
# retries given
(2, State.UP_FOR_RETRY),
# retries not given
(None, State.FAILED),
# retries given but as 0
(0, State.FAILED),
# retries not known, given as -1, calculates on table default
(-1, State.UP_FOR_RETRY),
],
)
def test_ti_update_state_to_retry(self, client, session, create_task_instance, retries, expected_state):
ti = create_task_instance(
task_id="test_ti_update_state_to_retry",
state=State.RUNNING,
)
ti.retries = retries
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": State.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
"task_retries": retries,
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
},
)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == expected_state
assert ti.next_method is None
assert ti.next_kwargs is None

def test_ti_update_state_to_retry_when_restarting(self, client, session, create_task_instance):
ti = create_task_instance(
task_id="test_ti_update_state_to_retry_when_restarting",
state=State.RESTARTING,
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": State.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == State.UP_FOR_RETRY
assert ti.next_method is None
assert ti.next_kwargs is None


class TestTIHealthEndpoint:
def setup_method(self):
Expand Down