Skip to content

Commit

Permalink
fix(providers/amazon): respect soft_fail argument when exception is r…
Browse files Browse the repository at this point in the history
…aised (apache#34134)
  • Loading branch information
Lee-W authored Sep 25, 2023
1 parent 5ac1e7c commit a4ecdc9
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 30 deletions.
15 changes: 13 additions & 2 deletions airflow/providers/amazon/aws/sensors/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowSkipException
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -57,7 +58,12 @@ def poke(self, context: Context):
return True
if stack_status in ("CREATE_IN_PROGRESS", None):
return False
raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")

# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Stack {self.stack_name} in bad state: {stack_status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise ValueError(message)

@cached_property
def hook(self) -> CloudFormationHook:
Expand Down Expand Up @@ -101,7 +107,12 @@ def poke(self, context: Context):
return True
if stack_status == "DELETE_IN_PROGRESS":
return False
raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")

# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Stack {self.stack_name} in bad state: {stack_status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise ValueError(message)

@cached_property
def hook(self) -> CloudFormationHook:
Expand Down
16 changes: 11 additions & 5 deletions airflow/providers/amazon/aws/sensors/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from deprecated import deprecated

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.dms import DmsHook
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -75,17 +75,23 @@ def poke(self, context: Context):
status: str | None = self.hook.get_task_status(self.replication_task_arn)

if not status:
raise AirflowException(
f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
)
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

self.log.info("DMS Replication task (%s) has status: %s", self.replication_task_arn, status)

if status in self.target_statuses:
return True

if status in self.termination_statuses:
raise AirflowException(f"Unexpected status: {status}")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Unexpected status: {status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

return False

Expand Down
8 changes: 6 additions & 2 deletions airflow/providers/amazon/aws/sensors/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.providers.amazon.aws.triggers.ec2 import EC2StateSensorTrigger
from airflow.sensors.base import BaseSensorOperator
Expand Down Expand Up @@ -94,5 +94,9 @@ def poke(self, context: Context):

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error: {event}")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
return
13 changes: 6 additions & 7 deletions airflow/providers/amazon/aws/sensors/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.eks import (
ClusterStates,
EksHook,
Expand Down Expand Up @@ -53,9 +53,6 @@
NodegroupStates.NONEXISTENT,
}
)
UNEXPECTED_TERMINAL_STATE_MSG = (
"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
)


class EksBaseSensor(BaseSensorOperator):
Expand Down Expand Up @@ -109,9 +106,11 @@ def poke(self, context: Context) -> bool:
self.log.info("Current state: %s", state)
if state in (self.get_terminal_states() - {self.target_state}):
# If we reach a terminal state which is not the target state:
raise AirflowException(
UNEXPECTED_TERMINAL_STATE_MSG.format(current_state=state, target_state=self.target_state)
)
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Terminal state reached. Current state: {state}, Expected state: {self.target_state}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
return state == self.target_state

@abstractmethod
Expand Down
40 changes: 33 additions & 7 deletions airflow/providers/amazon/aws/sensors/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from deprecated import deprecated

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.links.emr import EmrClusterLink, EmrLogsLink, get_log_uri
from airflow.providers.amazon.aws.triggers.emr import (
Expand Down Expand Up @@ -82,7 +82,11 @@ def poke(self, context: Context):
return True

if state in self.failed_states:
raise AirflowException(f"EMR job failed: {self.failure_message_from_response(response)}")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"EMR job failed: {self.failure_message_from_response(response)}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

return False

Expand Down Expand Up @@ -156,6 +160,9 @@ def poke(self, context: Context) -> bool:

if state in EmrServerlessHook.JOB_FAILURE_STATES:
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(failure_message)
raise AirflowException(failure_message)

return state in self.target_states
Expand Down Expand Up @@ -210,7 +217,10 @@ def poke(self, context: Context) -> bool:
state = response["application"]["state"]

if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
if self.soft_fail:
raise AirflowSkipException(failure_message)
raise AirflowException(failure_message)

return state in self.target_states
Expand Down Expand Up @@ -295,7 +305,11 @@ def poke(self, context: Context) -> bool:
)

if state in self.FAILURE_STATES:
raise AirflowException("EMR Containers sensor failed")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = "EMR Containers sensor failed"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

if state in self.INTERMEDIATE_STATES:
return False
Expand Down Expand Up @@ -323,7 +337,11 @@ def execute(self, context: Context):

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error while running job: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
else:
self.log.info(event["message"])

Expand Down Expand Up @@ -508,9 +526,13 @@ def execute(self, context: Context) -> None:
method_name="execute_complete",
)

def execute_complete(self, context, event=None):
def execute_complete(self, context: Context, event=None) -> None:
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error while running job: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
self.log.info("Job completed.")


Expand Down Expand Up @@ -637,6 +659,10 @@ def execute(self, context: Context) -> None:

def execute_complete(self, context, event=None):
if event["status"] != "success":
raise AirflowException(f"Error while running job: {event}")
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = f"Error while running job: {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

self.log.info("Job completed.")
16 changes: 9 additions & 7 deletions tests/providers/amazon/aws/sensors/test_eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
CLUSTER_TERMINAL_STATES,
FARGATE_TERMINAL_STATES,
NODEGROUP_TERMINAL_STATES,
UNEXPECTED_TERMINAL_STATE_MSG,
EksClusterStateSensor,
EksFargateProfileStateSensor,
EksNodegroupStateSensor,
Expand Down Expand Up @@ -75,8 +74,9 @@ def test_poke_reached_pending_state(self, mock_get_cluster_state, setUp, pending
def test_poke_reached_unexpected_terminal_state(
self, mock_get_cluster_state, setUp, unexpected_terminal_state
):
expected_message = UNEXPECTED_TERMINAL_STATE_MSG.format(
current_state=unexpected_terminal_state, target_state=self.target_state
expected_message = (
f"Terminal state reached. Current state: {unexpected_terminal_state}, "
f"Expected state: {self.target_state}"
)
mock_get_cluster_state.return_value = unexpected_terminal_state

Expand Down Expand Up @@ -122,8 +122,9 @@ def test_poke_reached_pending_state(self, mock_get_fargate_profile_state, setUp,
def test_poke_reached_unexpected_terminal_state(
self, mock_get_fargate_profile_state, setUp, unexpected_terminal_state
):
expected_message = UNEXPECTED_TERMINAL_STATE_MSG.format(
current_state=unexpected_terminal_state, target_state=self.target_state
expected_message = (
f"Terminal state reached. Current state: {unexpected_terminal_state}, "
f"Expected state: {self.target_state}"
)
mock_get_fargate_profile_state.return_value = unexpected_terminal_state

Expand Down Expand Up @@ -171,8 +172,9 @@ def test_poke_reached_pending_state(self, mock_get_nodegroup_state, setUp, pendi
def test_poke_reached_unexpected_terminal_state(
self, mock_get_nodegroup_state, setUp, unexpected_terminal_state
):
expected_message = UNEXPECTED_TERMINAL_STATE_MSG.format(
current_state=unexpected_terminal_state, target_state=self.target_state
expected_message = (
f"Terminal state reached. Current state: {unexpected_terminal_state}, "
f"Expected state: {self.target_state}"
)
mock_get_nodegroup_state.return_value = unexpected_terminal_state

Expand Down

0 comments on commit a4ecdc9

Please sign in to comment.