Skip to content

Commit

Permalink
Respect soft_fail argument when running BatchSensors (apache#34592)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsharma2 authored Sep 25, 2023
1 parent 2035dc7 commit 5a133e8
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 8 deletions.
37 changes: 29 additions & 8 deletions airflow/providers/amazon/aws/sensors/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,17 @@ def poke(self, context: Context) -> bool:
return False

if state == BatchClientHook.FAILURE_STATE:
raise AirflowException(f"Batch sensor failed. AWS Batch job status: {state}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Batch sensor failed. AWS Batch job status: {state}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"Batch sensor failed. Unknown AWS Batch job status: {state}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

def execute(self, context: Context) -> None:
if not self.deferrable:
Expand Down Expand Up @@ -182,7 +190,11 @@ def poke(self, context: Context) -> bool:
)

if not response["computeEnvironments"]:
raise AirflowException(f"AWS Batch compute environment {self.compute_environment} not found")
message = f"AWS Batch compute environment {self.compute_environment} not found"
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

status = response["computeEnvironments"][0]["status"]

Expand All @@ -192,9 +204,11 @@ def poke(self, context: Context) -> bool:
if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
return False

raise AirflowException(
f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
)
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)


class BatchJobQueueSensor(BaseSensorOperator):
Expand Down Expand Up @@ -250,7 +264,11 @@ def poke(self, context: Context) -> bool:
if self.treat_non_existing_as_deleted:
return True
else:
raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
# TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
message = f"AWS Batch job queue {self.job_queue} not found"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)

status = response["jobQueues"][0]["status"]

Expand All @@ -260,4 +278,7 @@ def poke(self, context: Context) -> bool:
if status in BatchClientHook.JOB_QUEUE_INTERMEDIATE_STATUS:
return False

raise AirflowException(f"AWS Batch job queue failed. AWS Batch job queue status: {status}")
message = f"AWS Batch job queue failed. AWS Batch job queue status: {status}"
if self.soft_fail:
raise AirflowSkipException(message)
raise AirflowException(message)
76 changes: 76 additions & 0 deletions tests/providers/amazon/aws/sensors/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,34 @@ def test_execute_failure_in_deferrable_mode_with_soft_fail(self, deferrable_batc
with pytest.raises(AirflowSkipException):
deferrable_batch_sensor.execute_complete(context={}, event={"status": "failure"})

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@pytest.mark.parametrize(
"state, error_message",
(
(
BatchClientHook.FAILURE_STATE,
f"Batch sensor failed. AWS Batch job status: {BatchClientHook.FAILURE_STATE}",
),
("unknown_state", "Batch sensor failed. Unknown AWS Batch job status: unknown_state"),
),
)
@mock.patch.object(BatchClientHook, "get_job_description")
def test_fail_poke(
self,
mock_get_job_description,
batch_sensor: BatchSensor,
state,
error_message,
soft_fail,
expected_exception,
):
mock_get_job_description.return_value = {"status": state}
batch_sensor.soft_fail = soft_fail
with pytest.raises(expected_exception, match=error_message):
batch_sensor.poke({})


@pytest.fixture(scope="module")
def batch_compute_environment_sensor() -> BatchComputeEnvironmentSensor:
Expand Down Expand Up @@ -174,6 +202,34 @@ def test_poke_invalid(
)
assert "AWS Batch compute environment failed" in str(ctx.value)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@pytest.mark.parametrize(
"compute_env, error_message",
(
(
[{"status": "unknown_status"}],
"AWS Batch compute environment failed. AWS Batch compute environment status:",
),
([], "AWS Batch compute environment"),
),
)
@mock.patch.object(BatchClientHook, "client")
def test_fail_poke(
self,
mock_batch_client,
batch_compute_environment_sensor: BatchComputeEnvironmentSensor,
compute_env,
error_message,
soft_fail,
expected_exception,
):
mock_batch_client.describe_compute_environments.return_value = {"computeEnvironments": compute_env}
batch_compute_environment_sensor.soft_fail = soft_fail
with pytest.raises(expected_exception, match=error_message):
batch_compute_environment_sensor.poke({})


@pytest.fixture(scope="module")
def batch_job_queue_sensor() -> BatchJobQueueSensor:
Expand Down Expand Up @@ -242,3 +298,23 @@ def test_poke_invalid(self, mock_batch_client, batch_job_queue_sensor: BatchJobQ
jobQueues=[JOB_QUEUE],
)
assert "AWS Batch job queue failed" in str(ctx.value)

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
)
@pytest.mark.parametrize("job_queue", ([], [{"status": "UNKNOWN_STATUS"}]))
@mock.patch.object(BatchClientHook, "client")
def test_fail_poke(
self,
mock_batch_client,
batch_job_queue_sensor: BatchJobQueueSensor,
job_queue,
soft_fail,
expected_exception,
):
mock_batch_client.describe_job_queues.return_value = {"jobQueues": job_queue}
batch_job_queue_sensor.treat_non_existing_as_deleted = False
batch_job_queue_sensor.soft_fail = soft_fail
message = "AWS Batch job queue"
with pytest.raises(expected_exception, match=message):
batch_job_queue_sensor.poke({})

0 comments on commit 5a133e8

Please sign in to comment.