Skip to content

Commit

Permalink
Respect soft_fail argument when running EcsBaseSensor (apache#34596)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsharma2 authored Sep 25, 2023
1 parent d1b7bca commit 84f70da
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 9 deletions.
18 changes: 10 additions & 8 deletions airflow/providers/amazon/aws/sensors/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
EcsHook,
Expand All @@ -36,11 +36,13 @@
DEFAULT_CONN_ID: str = "aws_default"


def _check_failed(current_state, target_state, failure_states):
def _check_failed(current_state, target_state, failure_states, soft_fail: bool) -> None:
if (current_state != target_state) and (current_state in failure_states):
raise AirflowException(
f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
)
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
if soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)


class EcsBaseSensor(BaseSensorOperator):
Expand Down Expand Up @@ -95,7 +97,7 @@ def poke(self, context: Context):
cluster_state = EcsClusterStates(self.hook.get_cluster_state(cluster_name=self.cluster_name))

self.log.info("Cluster state: %s, waiting for: %s", cluster_state, self.target_state)
_check_failed(cluster_state, self.target_state, self.failure_states)
_check_failed(cluster_state, self.target_state, self.failure_states, self.soft_fail)

return cluster_state == self.target_state

Expand Down Expand Up @@ -141,7 +143,7 @@ def poke(self, context: Context):
)

self.log.info("Task Definition state: %s, waiting for: %s", task_definition_state, self.target_state)
_check_failed(task_definition_state, self.target_state, [self.failure_states])
_check_failed(task_definition_state, self.target_state, [self.failure_states], self.soft_fail)
return task_definition_state == self.target_state


Expand Down Expand Up @@ -181,5 +183,5 @@ def poke(self, context: Context):
task_state = EcsTaskStates(self.hook.get_task_state(cluster=self.cluster, task=self.task))

self.log.info("Task state: %s, waiting for: %s", task_state, self.target_state)
_check_failed(task_state, self.target_state, self.failure_states)
_check_failed(task_state, self.target_state, self.failure_states, self.soft_fail)
return task_state == self.target_state
20 changes: 19 additions & 1 deletion tests/providers/amazon/aws/sensors/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest
from slugify import slugify

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.sensors.ecs import (
DEFAULT_CONN_ID,
EcsBaseSensor,
Expand All @@ -35,6 +35,7 @@
EcsTaskDefinitionStateSensor,
EcsTaskStates,
EcsTaskStateSensor,
_check_failed,
)
from airflow.utils import timezone
from airflow.utils.types import NOTSET
Expand Down Expand Up @@ -259,3 +260,20 @@ def test_custom_values_terminal_state(self, failure_states, return_state):
with pytest.raises(AirflowException, match="Terminal state reached"):
task.poke({})
m.assert_called_once_with(cluster=TEST_CLUSTER_NAME, task=TEST_TASK_ARN)


@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
def test_fail__check_failed(soft_fail, expected_exception):
current_state = "FAILED"
target_state = "SUCCESS"
failure_states = ["FAILED"]
message = f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
with pytest.raises(expected_exception, match=message):
_check_failed(
current_state=current_state,
target_state=target_state,
failure_states=failure_states,
soft_fail=soft_fail,
)

0 comments on commit 84f70da

Please sign in to comment.