Skip to content

Commit

Permalink
AIP-72: Handling task retries in task SDK + execution API
Browse files Browse the repository at this point in the history
closes: apache#44351
  • Loading branch information
amoghrajesh authored and kaxil committed Dec 27, 2024
1 parent 7f2b8ef commit ff1c5b3
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 19 deletions.
36 changes: 29 additions & 7 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TerminalTIState

# TODO: Add dependency on JWT token
router = AirflowRouter()
Expand Down Expand Up @@ -185,9 +185,13 @@ def ti_update_state(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)

old = select(TI.state).where(TI.id == ti_id_str).with_for_update()
old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == ti_id_str).with_for_update()
try:
(previous_state,) = session.execute(old).one()
(
previous_state,
try_number,
max_tries,
) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
Expand All @@ -205,11 +209,17 @@ def ti_update_state(

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
if ti_patch_payload.state == State.FAILED:
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
updated_state = ti_patch_payload.state
# if we get failed, we should attempt to retry, as it is a more
# normal state. Tasks with retries are more frequent than without retries.
if ti_patch_payload.state == TerminalTIState.FAIL_WITHOUT_RETRY:
updated_state = State.FAILED
elif ti_patch_payload.state == State.FAILED:
if _is_eligible_to_retry(previous_state, try_number, max_tries):
updated_state = State.UP_FOR_RETRY
else:
updated_state = State.FAILED
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
Expand Down Expand Up @@ -359,3 +369,15 @@ def ti_put_rtif(
_update_rtif(task_instance, put_rtif_payload, session)

return {"message": "Rendered task instance fields successfully set"}


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == State.RESTARTING:
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True

# max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for
# retries from the task SDK now, we can handle using max_tries
return max_tries != 0 and try_number <= max_tries
5 changes: 4 additions & 1 deletion airflow/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,12 @@ class TerminalTIState(str, Enum):
"""States that a Task Instance can be in that indicate it has reached a terminal state."""

SUCCESS = "success"
FAILED = "failed"
FAILED = "failed" # This state indicates that we attempt to retry.
SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task & it will be marked as skipped
REMOVED = "removed"
FAIL_WITHOUT_RETRY = (
"fail_without_retry" # This state is useful for when we want to terminate a task, without retrying.
)

def __str__(self) -> str:
return self.value
Expand Down
1 change: 0 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
"""Tell the API server that this TI has reached a terminal state."""
# TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing.
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))

self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
Expand Down
5 changes: 4 additions & 1 deletion task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,12 @@ class TerminalTIState(str, Enum):
"""

SUCCESS = "success"
FAILED = "failed"
FAILED = "failed" # This state indicates that we attempt to retry.
SKIPPED = "skipped"
REMOVED = "removed"
FAIL_WITHOUT_RETRY = (
"fail_without_retry" # This state is useful for when we want to terminate a task, without retrying.
)


class ValidationError(BaseModel):
Expand Down
9 changes: 4 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = TaskState(
state=TerminalTIState.FAILED,
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
)

Expand All @@ -434,16 +434,15 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# updated already be another UI API. So, these exceptions should ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
msg = TaskState(
state=TerminalTIState.FAILED,
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
)
# TODO: Run task failure callbacks here
except SystemExit:
...
except BaseException:
# TODO: Handle TI handle failure
raise

# TODO: Run task failure callbacks here
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)

Expand Down
16 changes: 16 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,18 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
# testing to see if supervisor can handle TaskState message with state as fail_with_retry
pytest.param(
TaskState(
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=timezone.parse("2024-10-31T12:00:00Z"),
),
b"",
"",
(),
"",
id="patch_task_instance_to_failed_with_retries",
),
pytest.param(
SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
b"",
Expand All @@ -903,6 +915,7 @@ def test_handle_requests(
client_attr_path,
method_arg,
mock_response,
time_machine,
):
"""
Test handling of different messages to the subprocess. For any new message type, add a
Expand All @@ -916,6 +929,9 @@ def test_handle_requests(
4. Verifies that the response is correctly decoded.
"""

instant = tz.datetime(2024, 11, 7, 12, 0, 0, 0)
time_machine.move_to(instant, tick=False)

# Mock the client method. E.g. `client.variables.get` or `client.connections.get`
mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client)
mock_client_method.return_value = mock_response
Expand Down
82 changes: 80 additions & 2 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,84 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context, mock_sup
)


def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context):
"""Test running a basic task that raises a base exception which should send fail_with_retry state."""
from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id="zero_division_error",
python_callable=lambda: 1 / 0,
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
task_id="zero_division_error",
dag_id="basic_dag_base_exception",
run_id="c",
try_number=1,
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)

ti = mocked_parse(what, "basic_dag_base_exception", task)

instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(
state=TerminalTIState.FAILED,
end_date=instant,
),
log=mock.ANY,
)


def test_startup_basic_templated_dag(mocked_parse, make_ti_context):
"""Test running a DAG with templated task."""
from airflow.providers.standard.operators.bash import BashOperator

task = BashOperator(
task_id="templated_task",
bash_command="echo 'Logical date is {{ logical_date }}'",
)

what = StartupDetails(
ti=TaskInstance(
id=uuid7(), task_id="templated_task", dag_id="basic_templated_dag", run_id="c", try_number=1
),
file="",
requests_fd=0,
ti_context=make_ti_context(),
)
mocked_parse(what, "basic_templated_dag", task)

with mock.patch(
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as mock_supervisor_comms:
mock_supervisor_comms.get_message.return_value = what
startup()

mock_supervisor_comms.send_request.assert_called_once_with(
msg=SetRenderedFields(
rendered_fields={
"bash_command": "echo 'Logical date is {{ logical_date }}'",
"cwd": None,
"env": None,
}
),
log=mock.ANY,
)


@pytest.mark.parametrize(
["task_params", "expected_rendered_fields"],
[
Expand Down Expand Up @@ -345,7 +423,7 @@ def execute(self, context):
),
],
)
def test_run_basic_failed(
def test_run_basic_failed_without_retries(
time_machine, mocked_parse, dag_id, task_id, fail_with_exception, make_ti_context, mock_supervisor_comms
):
"""Test running a basic task that marks itself as failed by raising exception."""
Expand Down Expand Up @@ -376,7 +454,7 @@ def execute(self, context):
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=TaskState(state=TerminalTIState.FAILED, end_date=instant), log=mock.ANY
msg=TaskState(state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=instant), log=mock.ANY
)


Expand Down
42 changes: 40 additions & 2 deletions tests/api_fastapi/execution_api/routes/test_task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.state import State, TaskInstanceState, TerminalTIState

from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields

Expand Down Expand Up @@ -234,7 +234,7 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
with mock.patch(
"airflow.api_fastapi.common.db.common.Session.execute",
side_effect=[
mock.Mock(one=lambda: ("running",)), # First call returns "queued"
mock.Mock(one=lambda: ("running", 1, 0)), # First call returns "queued"
SQLAlchemyError("Database error"), # Second call raises an error
],
):
Expand Down Expand Up @@ -340,6 +340,44 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan
assert trs[0].map_index == -1
assert trs[0].duration == 129600

@pytest.mark.parametrize(
("retries", "expected_state"),
[
(0, State.FAILED),
(None, State.FAILED),
(3, State.UP_FOR_RETRY),
],
)
def test_ti_update_state_to_failed_with_retries(
self, client, session, create_task_instance, retries, expected_state
):
ti = create_task_instance(
task_id="test_ti_update_state_to_retry",
state=State.RUNNING,
)

if retries is not None:
ti.max_tries = retries
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": TerminalTIState.FAILED,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)

assert response.status_code == 204
assert response.text == ""

session.expire_all()

ti = session.get(TaskInstance, ti.id)
assert ti.state == expected_state
assert ti.next_method is None
assert ti.next_kwargs is None

def test_ti_update_state_to_failed_table_check(self, client, session, create_task_instance):
ti = create_task_instance(
task_id="test_ti_update_state_to_failed_table_check",
Expand Down

0 comments on commit ff1c5b3

Please sign in to comment.