
Commit

failed = fail after retrying, fail_without_retry = just fail
amoghrajesh committed Dec 27, 2024
1 parent 873e765 commit 70c7c73
Showing 12 changed files with 59 additions and 123 deletions.
3 changes: 0 additions & 3 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -59,9 +59,6 @@ class TITerminalStatePayload(BaseModel):
end_date: UtcDateTime
"""When the task completed executing"""

"""Indicates if the task should retry before failing or not."""
should_retry: bool = False


class TITargetStatePayload(BaseModel):
"""Schema for updating TaskInstance to a target state, excluding terminal and running states."""
46 changes: 21 additions & 25 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -44,7 +44,7 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.utils import timezone
from airflow.utils.state import State
from airflow.utils.state import State, TerminalTIState

# TODO: Add dependency on JWT token
router = AirflowRouter()
@@ -185,9 +185,13 @@ def ti_update_state(
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)

old = select(TI.state).where(TI.id == ti_id_str).with_for_update()
old = select(TI.state, TI.try_number, TI.max_tries).where(TI.id == ti_id_str).with_for_update()
try:
(previous_state,) = session.execute(old).one()
(
previous_state,
try_number,
max_tries,
) = session.execute(old).one()
except NoResultFound:
log.error("Task Instance %s not found", ti_id_str)
raise HTTPException(
@@ -199,22 +203,23 @@
)

# We exclude_unset to avoid updating fields that are not set in the payload
# We do not need to deserialize "should_retry" -- it is used for dynamic decision-making within failed state
data = ti_patch_payload.model_dump(exclude_unset=True, exclude={"should_retry"})
data = ti_patch_payload.model_dump(exclude_unset=True)

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
query = query.values(state=ti_patch_payload.state)
updated_state = ti_patch_payload.state
if ti_patch_payload.state == State.FAILED:
# clear the next_method and next_kwargs
query = query.values(next_method=None, next_kwargs=None)
task_instance = session.get(TI, ti_id_str)
if _is_eligible_to_retry(task_instance, ti_patch_payload.should_retry):
query = query.values(state=State.UP_FOR_RETRY)
# If we get FAILED, we should attempt to retry, as that is the more
# common case: tasks configured with retries are more frequent than ones without.
if ti_patch_payload.state == TerminalTIState.FAIL_WITHOUT_RETRY:
updated_state = State.FAILED
elif ti_patch_payload.state == State.FAILED:
if _is_eligible_to_retry(previous_state, try_number, max_tries):
updated_state = State.UP_FOR_RETRY
else:
updated_state = State.FAILED
query = query.values(state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
@@ -366,22 +371,13 @@ def ti_put_rtif(
return {"message": "Rendered task instance fields successfully set"}


def _is_eligible_to_retry(task_instance: TI, should_retry: bool) -> bool:
"""
Whether the task instance is eligible for retry.
:param task_instance: the task instance
:meta private:
"""
if not should_retry:
return False

if task_instance.state == State.RESTARTING:
def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == State.RESTARTING:
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True

# max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for
# retries from the task SDK now, we can handle using max_tries
return task_instance.max_tries and task_instance.try_number <= task_instance.max_tries
return max_tries != 0 and try_number <= max_tries
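
Taken together, the endpoint above now resolves the stored state entirely on the server. Below is a minimal, self-contained sketch of that decision, assuming the same enum values as this commit; `resolve_terminal_state` is an illustrative name, not a function in the diff:

```python
from enum import Enum


class TerminalTIState(str, Enum):
    """Mirrors the enum in airflow/utils/state.py at this commit."""

    SUCCESS = "success"
    FAILED = "failed"  # retried if eligible
    SKIPPED = "skipped"
    REMOVED = "removed"
    FAIL_WITHOUT_RETRY = "fail_without_retry"  # never retried


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
    # A task cleared while running goes to "restarting" and is always retryable.
    if state == "restarting":
        return True
    return max_tries != 0 and try_number <= max_tries


def resolve_terminal_state(
    requested: TerminalTIState, previous_state: str, try_number: int, max_tries: int
) -> str:
    """Map a requested terminal state to the state actually stored, as ti_update_state does."""
    if requested == TerminalTIState.FAIL_WITHOUT_RETRY:
        return "failed"  # hard failure: retries are skipped entirely
    if requested == TerminalTIState.FAILED:
        if _is_eligible_to_retry(previous_state, try_number, max_tries):
            return "up_for_retry"
        return "failed"
    return requested.value


# e.g. a second try of a task with max_tries=3 fails retryably:
assert resolve_terminal_state(TerminalTIState.FAILED, "running", 2, 3) == "up_for_retry"
assert resolve_terminal_state(TerminalTIState.FAIL_WITHOUT_RETRY, "running", 2, 3) == "failed"
```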
5 changes: 4 additions & 1 deletion airflow/utils/state.py
@@ -36,9 +36,12 @@ class TerminalTIState(str, Enum):
"""States that a Task Instance can be in that indicate it has reached a terminal state."""

SUCCESS = "success"
FAILED = "failed"
FAILED = "failed" # This state indicates that we attempt to retry.
SKIPPED = "skipped" # A user can raise a AirflowSkipException from a task & it will be marked as skipped
REMOVED = "removed"
FAIL_WITHOUT_RETRY = (
"fail_without_retry"  # This state is useful when we want to terminate a task without retrying.
)

def __str__(self) -> str:
return self.value
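
For reference, both failure flavors serialize to plain strings over the API; a quick sanity check against this enum (import path as in this diff) might look like:

```python
from airflow.utils.state import TerminalTIState

assert TerminalTIState.FAILED.value == "failed"
assert TerminalTIState.FAIL_WITHOUT_RETRY.value == "fail_without_retry"
# __str__ returns .value, so string formatting yields the wire value directly.
assert str(TerminalTIState.FAIL_WITHOUT_RETRY) == "fail_without_retry"
```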
5 changes: 0 additions & 5 deletions task_sdk/src/airflow/sdk/api/client.py
@@ -130,11 +130,6 @@ def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime):
body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state))
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def fail(self, id: uuid.UUID, when: datetime, should_retry: bool):
"""Tell the API server that this TI has to fail, with or without retries."""
body = TITerminalStatePayload(end_date=when, state=TerminalTIState.FAILED, should_retry=should_retry)
self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json())

def heartbeat(self, id: uuid.UUID, pid: int):
body = TIHeartbeatInfo(pid=pid, hostname=get_hostname())
self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json())
6 changes: 4 additions & 2 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -126,9 +126,12 @@ class TerminalTIState(str, Enum):
"""

SUCCESS = "success"
FAILED = "failed"
FAILED = "failed" # This state indicates that we attempt to retry.
SKIPPED = "skipped"
REMOVED = "removed"
FAIL_WITHOUT_RETRY = (
"fail_without_retry"  # This state is useful when we want to terminate a task without retrying.
)


class ValidationError(BaseModel):
@@ -217,4 +220,3 @@ class TITerminalStatePayload(BaseModel):

state: TerminalTIState
end_date: Annotated[datetime, Field(title="End Date")]
should_retry: Annotated[bool | None, Field(title="Should Retry")] = False
17 changes: 1 addition & 16 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
@@ -145,25 +145,11 @@ class TaskState(BaseModel):
- anything else = FAILED
"""

state: Literal[TerminalTIState.SUCCESS, TerminalTIState.REMOVED, TerminalTIState.SKIPPED]
state: TerminalTIState
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"


class FailState(BaseModel):
"""
Update a task's state to FAILED.
Contains attributes specific to failing a task instance, such as the
ability to retry.
"""

should_retry: bool = True
end_date: datetime | None = None
state: Literal[TerminalTIState.FAILED] = TerminalTIState.FAILED
type: Literal["FailState"] = "FailState"


class DeferTask(TIDeferredStatePayload):
"""Update a task instance state to deferred."""

@@ -246,7 +232,6 @@ class SetRenderedFields(BaseModel):
ToSupervisor = Annotated[
Union[
TaskState,
FailState,
GetXCom,
GetConnection,
GetVariable,
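
Since `TaskState.state` now accepts any `TerminalTIState`, both failure flavors travel over the one message type; a hedged construction example (import paths as used elsewhere in this diff):

```python
from datetime import datetime, timezone

from airflow.sdk.api.datamodels._generated import TerminalTIState
from airflow.sdk.execution_time.comms import TaskState

now = datetime.now(tz=timezone.utc)

# Retry-eligible failure: the server may store this as up_for_retry.
retryable = TaskState(state=TerminalTIState.FAILED, end_date=now)

# Hard failure: stored as failed regardless of remaining tries.
terminal = TaskState(state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=now)

assert retryable.type == "TaskState" and terminal.type == "TaskState"
```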
15 changes: 2 additions & 13 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -62,7 +62,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
FailState,
GetConnection,
GetVariable,
GetXCom,
@@ -294,9 +293,6 @@ class WatchedSubprocess:
_exit_code: int | None = attrs.field(default=None, init=False)
_terminal_state: str | None = attrs.field(default=None, init=False)
_final_state: str | None = attrs.field(default=None, init=False)
# denotes whether a request to `fail` has already been sent from _handle_requests, or will be handled in wait();
# useful to synchronise the `fail` API requests between handle_requests and wait
_fail_request_sent: bool = attrs.field(default=False, init=False)

_last_successful_heartbeat: float = attrs.field(default=0, init=False)
_last_heartbeat_attempt: float = attrs.field(default=0, init=False)
@@ -525,7 +521,7 @@ def wait(self) -> int:
# to reflect the final state of the process.
# For states like `deferred`, the process will exit with 0, but the state will be updated
# by the subprocess in the `handle_requests` method.
if (not self._fail_request_sent) and self.final_state in TerminalTIState:
if self.final_state in TerminalTIState:
self.client.task_instances.finish(
id=self.id, state=self.final_state, when=datetime.now(tz=timezone.utc)
)
@@ -714,14 +710,7 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]
def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
log.debug("Received message from task runner", msg=msg)
resp = None
if isinstance(msg, FailState):
self._terminal_state = TerminalTIState.FAILED
self._task_end_time_monotonic = time.monotonic()
self._fail_request_sent = True
log.debug("IN SIDE FAILSTATE.")
self.client.task_instances.fail(self.id, datetime.now(tz=timezone.utc), msg.should_retry)
elif isinstance(msg, TaskState):
log.debug("IN SIDE TaskState.")
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, GetConnection):
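
The net effect in the supervisor: every terminal state is recorded from a `TaskState` message and reported once, from `wait()`, via `finish`. A self-contained sketch of that flow, with the API client assumed configured elsewhere and class internals elided — `SupervisorSketch` is an illustrative stand-in, not the real `WatchedSubprocess`:

```python
import time
from datetime import datetime, timezone

from airflow.sdk.api.datamodels._generated import TerminalTIState
from airflow.sdk.execution_time.comms import TaskState


class SupervisorSketch:
    """Illustrative stand-in for WatchedSubprocess."""

    def __init__(self, client, ti_id):
        self.client = client
        self.id = ti_id
        self._terminal_state = None
        self._task_end_time_monotonic = None

    def handle_request(self, msg):
        if isinstance(msg, TaskState):
            # Only record the state here; the API call happens once, in wait().
            self._terminal_state = msg.state
            self._task_end_time_monotonic = time.monotonic()

    def wait(self):
        # One code path for every terminal state, retry-eligible or not.
        if isinstance(self._terminal_state, TerminalTIState):
            self.client.task_instances.finish(
                id=self.id, state=self._terminal_state, when=datetime.now(tz=timezone.utc)
            )
```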
13 changes: 5 additions & 8 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -34,7 +34,6 @@
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.execution_time.comms import (
DeferTask,
FailState,
GetXCom,
RescheduleTask,
SetRenderedFields,
@@ -409,10 +408,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):

# TODO: Handle fail_stop here: https://github.com/apache/airflow/issues/44951
# TODO: Handle addition to Log table: https://github.com/apache/airflow/issues/44952
msg = FailState(
state=TerminalTIState.FAILED,
msg = TaskState(
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
should_retry=False,
)

# TODO: Run task failure callbacks here
@@ -423,16 +421,15 @@
# External state updates are already handled with `ti_heartbeat` and will
# already be updated by another UI API. So, these exceptions should ideally never be thrown.
# If these are thrown, we should mark the TI state as failed.
msg = FailState(
state=TerminalTIState.FAILED,
msg = TaskState(
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=datetime.now(tz=timezone.utc),
should_retry=False,
)
# TODO: Run task failure callbacks here
except SystemExit:
...
except BaseException:
msg = FailState(should_retry=True, end_date=datetime.now(tz=timezone.utc))
msg = TaskState(state=TerminalTIState.FAILED, end_date=datetime.now(tz=timezone.utc))
if msg:
SUPERVISOR_COMMS.send_request(msg=msg, log=log)

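
The runner-side mapping reduces to: known fatal paths send `FAIL_WITHOUT_RETRY`, while an unexpected `BaseException` sends plain `FAILED` and leaves retry eligibility to the server. A hedged helper expressing that mapping; `failure_message` and `retryable` are illustrative names, not part of this commit:

```python
from datetime import datetime, timezone

from airflow.sdk.api.datamodels._generated import TerminalTIState
from airflow.sdk.execution_time.comms import TaskState


def failure_message(*, retryable: bool) -> TaskState:
    """Build the message run() sends on failure."""
    state = TerminalTIState.FAILED if retryable else TerminalTIState.FAIL_WITHOUT_RETRY
    return TaskState(state=state, end_date=datetime.now(tz=timezone.utc))


# Unexpected BaseException -> retry-eligible failure:
assert failure_message(retryable=True).state == TerminalTIState.FAILED
# Known fatal paths -> terminal failure, no retry:
assert failure_message(retryable=False).state == TerminalTIState.FAIL_WITHOUT_RETRY
```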
20 changes: 1 addition & 19 deletions task_sdk/tests/api/test_client.py
@@ -121,7 +121,7 @@ def handle_request(request: httpx.Request) -> httpx.Response:
resp = client.task_instances.start(ti_id, 100, start_date)
assert resp == ti_context

@pytest.mark.parametrize("state", [state for state in TerminalTIState if state != TerminalTIState.FAILED])
@pytest.mark.parametrize("state", [state for state in TerminalTIState])
def test_task_instance_finish(self, state):
# Simulate a successful response from the server that finishes (moved to terminal state) a task
ti_id = uuid6.uuid7()
@@ -139,24 +139,6 @@ def handle_request(request: httpx.Request) -> httpx.Response:
client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.finish(ti_id, state=state, when="2024-10-31T12:00:00Z")

def test_task_instance_fail(self):
# Simulate a successful response from the server that fails a task with retry.
ti_id = uuid6.uuid7()

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == f"/task-instances/{ti_id}/state":
actual_body = json.loads(request.read())
assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
assert actual_body["state"] == TerminalTIState.FAILED
assert actual_body["should_retry"] is True
return httpx.Response(
status_code=204,
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
client.task_instances.fail(ti_id, when="2024-10-31T12:00:00Z", should_retry=True)

def test_task_instance_heartbeat(self):
# Simulate a successful response from the server that sends a heartbeat for a ti
ti_id = uuid6.uuid7()
14 changes: 6 additions & 8 deletions task_sdk/tests/execution_time/test_supervisor.py
@@ -41,7 +41,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
FailState,
GetConnection,
GetVariable,
GetXCom,
@@ -885,18 +884,17 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_skipped",
),
# testing to see if supervisor can handle FailState message
# testing to see if supervisor can handle a TaskState message with state fail_without_retry
pytest.param(
FailState(
state=TerminalTIState.FAILED,
TaskState(
state=TerminalTIState.FAIL_WITHOUT_RETRY,
end_date=timezone.parse("2024-10-31T12:00:00Z"),
should_retry=False,
),
b"",
"task_instances.fail",
(TI_ID, timezone.parse("2024-11-7T12:00:00Z"), False),
"",
id="patch_task_instance_to_failed",
(),
"",
id="patch_task_instance_to_failed_with_retries",
),
pytest.param(
SetRenderedFields(rendered_fields={"field1": "rendered_value1", "field2": "rendered_value2"}),
16 changes: 4 additions & 12 deletions task_sdk/tests/execution_time/test_task_runner.py
@@ -37,7 +37,6 @@
from airflow.sdk.execution_time.comms import (
ConnectionResult,
DeferTask,
FailState,
GetConnection,
SetRenderedFields,
StartupDetails,
@@ -257,20 +256,14 @@ def test_run_basic_skipped(time_machine, mocked_parse, make_ti_context):
)


@pytest.mark.parametrize(
"retries",
[None, 0, 3],
)
def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context, retries):
"""Test running a basic task that raises a base exception."""
def test_run_raises_base_exception(time_machine, mocked_parse, make_ti_context):
"""Test running a basic task that raises a base exception which should send fail_with_retry state."""
from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
task_id="zero_division_error",
python_callable=lambda: 1 / 0,
)
if retries is not None:
task.retries = retries

what = StartupDetails(
ti=TaskInstance(
@@ -296,8 +289,7 @@ def execute(self, context):
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=FailState(
should_retry=True,
msg=TaskState(
state=TerminalTIState.FAILED,
end_date=instant,
),
@@ -469,7 +461,7 @@ def execute(self, context):
run(ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_called_once_with(
msg=FailState(state=TerminalTIState.FAILED, end_date=instant, should_retry=False), log=mock.ANY
msg=TaskState(state=TerminalTIState.FAIL_WITHOUT_RETRY, end_date=instant), log=mock.ANY
)


