Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handling task retries in task SDK + execution API #45106

Merged
merged 11 commits into from
Dec 30, 2024
36 changes: 29 additions & 7 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TerminalTIState

# TODO: Add dependency on JWT token
router = AirflowRouter()
Expand Down Expand Up @@ -185,9 +185,13 @@ def ti_update_state(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)

old = select(TI.state).where(TI.id == ti_id_str).with_for_update()
old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == ti_id_str).with_for_update()
try:
(previous_state,) = session.execute(old).one()
(
previous_state,
try_number,
max_tries,
) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
Expand All @@ -205,11 +209,17 @@ def ti_update_state(

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
if ti_patch_payload.state == State.FAILED:
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
updated_state = ti_patch_payload.state
# A plain `failed` from the task means "retry if eligible": tasks with retries are more common
# than tasks without, so attempting a retry is treated as the normal path.
if ti_patch_payload.state == TerminalTIState.FAIL_WITHOUT_RETRY:
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
updated_state = State.FAILED
elif ti_patch_payload.state == State.FAILED:
if _is_eligible_to_retry(previous_state, try_number, max_tries):
updated_state = State.UP_FOR_RETRY
else:
updated_state = State.FAILED
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down Expand Up @@ -359,3 +369,15 @@ def ti_put_rtif(
_update_rtif(task_instance, put_rtif_payload, session)

return {"message": "Rendered task instance fields successfully set"}


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
    """
    Decide whether a task instance that reported a failure should be retried.

    :param state: state the task instance was in before this update
    :param try_number: current try number of the task instance
    :param max_tries: ``max_tries`` of the task instance; ``0`` means it is never retried
    :return: ``True`` if the task instance should be retried, ``False`` otherwise
    """
    restarting = state == State.RESTARTING
    if not restarting:
        # max_tries is initialised from the retries defined at task level, so there is no need to
        # ask the task SDK for retries explicitly; max_tries alone is enough to decide.
        has_retries = max_tries != 0
        return has_retries and try_number <= max_tries

    # A task that is cleared while running goes into RESTARTING and is always eligible for retry.
    return True
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 4 additions & 1 deletion airflow/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ class TerminalTIState(str, Enum):
"""States that a Task Instance can be in that indicate it has reached a terminal state."""

SUCCESS = "success"
FAILED = "failed"
FAILED = "failed" # This state indicates that we attempt to retry.
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task & it will be marked as skipped
REMOVED = "removed"
FAIL_WITHOUT_RETRY = (
"fail_without_retry" # This state is useful for when we want to terminate a task, without retrying.
)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved

def __str__(self) -> str:
return self.value
Expand Down
1 change: 0 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
# TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))

self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
Expand Down
5 changes: 4 additions & 1 deletion task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,12 @@ class TerminalTIState(str, Enum):
"""

SUCCESS = "success"
FAILED = "failed"
FAILED = "failed" # This state indicates that we attempt to retry.
SKIPPED = "skipped"
REMOVED = "removed"
FAIL_WITHOUT_RETRY = (
"fail_without_retry" # This state is useful for when we want to terminate a task, without retrying.
)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved


class ValidationError(BaseModel):
Expand Down
9 changes: 4 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
state=TerminalTIState.FAILED,
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
)

Expand All @@ -422,16 +422,15 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# updated already by another UI API. So, these exceptions should ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
msg = TaskState(
state=TerminalTIState.FAILED,
state=TerminalTIState.FAIL_WITHOUT_RETRY,
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
end_date=datetime.now(tz=timezone.utc),
)
# TODO: Run task failure callbacks here
except SystemExit:
...
except BaseException:
# TODO: Handle TI handle failure
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
raise

# TODO: Run task failure callbacks here
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)

Expand Down
16 changes: 16 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,18 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
# testing to see if supervisor can handle TaskState message with state as fail_without_retry
pytest.param(
TaskState(
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=timezone.parse("2024-10-31T12:00:00Z"),
),
b"",
"",
(),
"",
id="patch_task_instance_to_failed_with_retries",
),
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
pytest.param(
SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
b"",
Expand All @@ -903,6 +915,7 @@ def test_handle_requests(
client_attr_path,
method_arg,
mock_response,
time_machine,
):
"""
Test handling of different messages to the subprocess. For any new message type, add a
Expand All @@ -916,6 +929,9 @@ def test_handle_requests(
4. Verifies that the response is correctly decoded.
"""

instant = tz.datetime(2024, 11, 7, 12, 0, 0, 0)
time_machine.move_to(instant, tick=False)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved

# Mock the client method. E.g. `client.variables.get` or `client.connections.get`
mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client)
mock_client_method.return_value = mock_response
Expand Down
84 changes: 82 additions & 2 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,84 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
)


def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context):
    """Test that a task whose callable raises reports ``TerminalTIState.FAILED`` to the supervisor."""
    from airflow.providers.standard.operators.python import PythonOperator

    failing_task = PythonOperator(
        task_id="zero_division_error",
        # Dividing by zero makes the callable raise as soon as it is invoked.
        python_callable=lambda: 1 / 0,
    )

    ti_details = TaskInstance(
        id=uuid7(),
        task_id="zero_division_error",
        dag_id="basic_dag_base_exception",
        run_id="c",
        try_number=1,
    )
    startup_details = StartupDetails(ti=ti_details, file="", requests_fd=0, ti_context=make_ti_context())

    ti = mocked_parse(startup_details, "basic_dag_base_exception", failing_task)

    # Pin the current time (tick=False) so the expected end_date can be asserted exactly.
    frozen_now = timezone.datetime(2024, 12, 3, 10, 0)
    time_machine.move_to(frozen_now, tick=False)

    comms_target = "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS"
    with mock.patch(comms_target, create=True) as supervisor_comms:
        run(ti, log=mock.MagicMock())

        expected_msg = TaskState(state=TerminalTIState.FAILED, end_date=frozen_now)
        supervisor_comms.send_request.assert_called_once_with(msg=expected_msg, log=mock.ANY)


def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
    """Test that ``startup()`` sends the task's rendered fields to the supervisor for a templated task."""
    from airflow.providers.standard.operators.bash import BashOperator

    templated_task = BashOperator(
        task_id="templated_task",
        bash_command="echo 'Logical date is {{ logical_date }}'",
    )

    ti_details = TaskInstance(
        id=uuid7(),
        task_id="templated_task",
        dag_id="basic_templated_dag",
        run_id="c",
        try_number=1,
    )
    startup_details = StartupDetails(ti=ti_details, file="", requests_fd=0, ti_context=make_ti_context())
    mocked_parse(startup_details, "basic_templated_dag", templated_task)

    comms_target = "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS"
    with mock.patch(comms_target, create=True) as supervisor_comms:
        # startup() reads its StartupDetails from the supervisor channel, so hand it ours.
        supervisor_comms.get_message.return_value = startup_details
        startup()

        expected_fields = {
            "bash_command": "echo 'Logical date is {{ logical_date }}'",
            "cwd": None,
            "env": None,
        }
        supervisor_comms.send_request.assert_called_once_with(
            msg=SetRenderedFields(rendered_fields=expected_fields),
            log=mock.ANY,
        )


@pytest.mark.parametrize(
["task_params", "expected_rendered_fields"],
[
Expand Down Expand Up @@ -349,7 +427,9 @@ def execute(self, context):
),
],
)
def test_run_basic_failed(time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context):
def test_run_basic_failed_without_retries(
time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context
):
"""Test running a basic task that marks itself as failed by raising exception."""

class CustomOperator(BaseOperator):
Expand Down Expand Up @@ -381,7 +461,7 @@ def execute(self, context):
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
msg=TaskState(state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=instant), log=mock.ANY
)


Expand Down
42 changes: 40 additions & 2 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.state import State, TaskInstanceState, TerminalTIState

from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields

Expand Down Expand Up @@ -234,7 +234,7 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
with mock.patch(
"airflow.api_fastapi.common.db.common.Session.execute",
side_effect=[
mock.Mock(one=lambda: ("running",)), # First call returns "queued"
mock.Mock(one=lambda: ("running", 1, 0)), # First call returns "queued"
SQLAlchemyError("Database error"), # Second call raises an error
],
):
Expand Down Expand Up @@ -340,6 +340,44 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
assert trs[0].map_index == -1
assert trs[0].duration == 129600

@pytest.mark.parametrize(
    ("retries", "expected_state"),
    [
        (0, State.FAILED),
        (None, State.FAILED),
        (3, State.UP_FOR_RETRY),
    ],
)
def test_ti_update_state_to_failed_with_retries(
    self, client, session, create_task_instance, retries, expected_state
):
    """
    Patch a running TI to ``failed`` and check which state ends up stored for it.

    ``retries`` is written to ``max_tries`` before the request, unless it is ``None``, in which case
    ``max_tries`` is left at whatever the fixture created.
    """
    ti = create_task_instance(task_id="test_ti_update_state_to_retry", state=State.RUNNING)

    if retries is not None:
        ti.max_tries = retries
    session.commit()

    payload = {
        "state": TerminalTIState.FAILED,
        "end_date": DEFAULT_END_DATE.isoformat(),
    }
    response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload)

    assert response.status_code == 204
    assert response.text == ""

    # Drop cached ORM state so the assertions below read what the endpoint actually wrote.
    session.expire_all()

    refreshed_ti = session.get(TaskInstance, ti.id)
    assert refreshed_ti.state == expected_state
    assert refreshed_ti.next_method is None
    assert refreshed_ti.next_kwargs is None

def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
ti = create_task_instance(
task_id="test_ti_update_state_to_failed_table_check",
Expand Down
Loading