Skip to content

Commit

Permalink
Respect soft_fail argument when running SqsSensor (apache#34569)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsharma2 authored Sep 25, 2023
1 parent 84f70da commit 2b5c767
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
14 changes: 11 additions & 3 deletions airflow/providers/amazon/aws/sensors/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from typing_extensions import Literal

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
from airflow.providers.amazon.aws.utils.sqs import process_response
Expand Down Expand Up @@ -145,7 +145,11 @@ def execute(self, context: Context) -> Any:

def execute_complete(self, context: Context, event: dict | None = None) -> None:
if event is None or event["status"] != "success":
raise AirflowException(f"Trigger error: event is {event}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Trigger error: event is {event}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
context["ti"].xcom_push(key="messages", value=event["message_batch"])

def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
Expand Down Expand Up @@ -203,7 +207,11 @@ def poke(self, context: Context):
response = self.hook.conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)

if "Successful" not in response:
raise AirflowException(f"Delete SQS Messages failed {response} for messages {messages}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
error_message = f"Delete SQS Messages failed {response} for messages {messages}"
if self.soft_fail:
raise AirflowSkipException(error_message)
raise AirflowException(error_message)
if message_batch:
context["ti"].xcom_push(key="messages", value=message_batch)
return True
Expand Down
48 changes: 47 additions & 1 deletion tests/providers/amazon/aws/sensors/test_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import pytest
from moto import mock_sqs

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
Expand Down Expand Up @@ -346,3 +346,49 @@ def test_sqs_deferrable(self):
)
with pytest.raises(TaskDeferred):
self.sensor.execute(None)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
def test_fail_execute_complete(self, soft_fail, expected_exception):
self.sensor = SqsSensor(
task_id="test_task_deferrable",
dag=self.dag,
sqs_queue=QUEUE_URL,
aws_conn_id="aws_default",
max_messages=1,
num_batches=3,
deferrable=True,
soft_fail=soft_fail,
)
event = {"status": "failed"}
message = f"Trigger error: event is {event}"
with pytest.raises(expected_exception, match=message):
self.sensor.execute_complete(context={}, event=event)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@mock.patch("airflow.providers.amazon.aws.sensors.sqs.SqsSensor.poll_sqs")
@mock.patch("airflow.providers.amazon.aws.sensors.sqs.process_response")
@mock.patch("airflow.providers.amazon.aws.hooks.sqs.SqsHook.conn")
def test_fail_poke(self, conn, process_response, poll_sqs, soft_fail, expected_exception):
self.sensor = SqsSensor(
task_id="test_task_deferrable",
dag=self.dag,
sqs_queue=QUEUE_URL,
aws_conn_id="aws_default",
max_messages=1,
num_batches=3,
deferrable=True,
soft_fail=soft_fail,
)
response = "error message"
messages = [{"MessageId": "1", "ReceiptHandle": "test"}]
poll_sqs.return_value = response
process_response.return_value = messages
conn.delete_message_batch.return_value = response
error_message = f"Delete SQS Messages failed {response} for messages"
self.sensor.delete_message_on_reception = True
with pytest.raises(expected_exception, match=error_message):
self.sensor.poke(context={})

0 comments on commit 2b5c767

Please sign in to comment.