Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HTTP retry handling into task SDK api.client #45121

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions task_sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"msgspec>=0.18.6",
"psutil>=6.1.0",
"structlog>=24.4.0",
"retryhttp>=1.2.0",
]
classifiers = [
"Framework :: Apache Airflow",
Expand Down
27 changes: 27 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from __future__ import annotations

import logging
import os
import sys
import uuid
from http import HTTPStatus
Expand All @@ -26,6 +28,8 @@
import msgspec
import structlog
from pydantic import BaseModel
from retryhttp import retry, wait_retry_after
from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7

from airflow.sdk import __version__
Expand Down Expand Up @@ -263,6 +267,14 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json={"text": "Hello, world!"})


# Config options for how the SDK retries failed HTTP requests.
# Note: With the given defaults, retries are attempted roughly 1, 3, 7, 15, 31 seconds, 1:03, 2:07 and
# 3:37 min after the first failure, and the request finally fails after about 5:07 min.
# NOTE(review): the waits come from a randomized exponential back-off, so these look like upper
# bounds rather than exact times - confirm against tenacity's wait_random_exponential.
# As long as there is no other config facility in the SDK we use environment variables for the moment.
def _int_from_env(name: str, default: int) -> int:
    """
    Read an integer setting from the environment.

    :param name: Name of the environment variable to read.
    :param default: Value returned when the variable is not set.
    :return: The parsed integer, or ``default`` if the variable is unset.
    :raises ValueError: If the variable is set but is not a valid integer. The message names the
        variable so that a misconfiguration is easy to track down.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # A bare int() error would only show the bad literal, not which setting it came from.
        raise ValueError(f"Environment variable {name} must be an integer, got {raw!r}") from None


API_RETRIES = _int_from_env("AIRFLOW__WORKERS__API_RETRIES", 10)
API_RETRY_WAIT_MIN = _int_from_env("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1)
API_RETRY_WAIT_MAX = _int_from_env("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90)

Copy link
Contributor

@shubhamraj-git shubhamraj-git Dec 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two things here:

  1. Currently, API_RETRIES, API_RETRY_WAIT_MIN, and API_RETRY_WAIT_MAX are cast directly to integers. This raises a ValueError if the environment variables are not set to valid integers. Can we add validation, or fall back to the defaults in case of an error (if we don't want to fail)?
  2. Let's also check that API_RETRY_WAIT_MIN < API_RETRY_WAIT_MAX and raise an error if not?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good points. But as the code comment above says, a "proper config mechanism" is still missing, so I would leave it as it is for now.
The values are ultimately passed to tenacity, and even there the logic does not seem to validate them. If you "mess up" the values, it seems it will simply use the max as the minimum. Not too bad.

For me it is okay like this, unless @kaxil or @ashb want me to make it bullet-proof here. My aim was that it works by default but that the defaults can be overridden: no "official" config, but something that can be tweaked if you are looking into the code.


class Client(httpx.Client):
def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
if (not base_url) ^ dry_run:
Expand All @@ -284,6 +296,21 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
**kwargs,
)

_default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)

@retry(
    max_attempt_number=API_RETRIES,
    wait_timeouts=_default_wait,
    wait_network_errors=_default_wait,
    wait_server_errors=_default_wait,
    # On HTTP 429 use the wait derived from the response, with the default wait as fallback:
    # no infinite timeout.
    wait_rate_limited=wait_retry_after(fallback=_default_wait),
    # NOTE(review): tenacity also ships `before_sleep_log`, which looks like the hook meant for
    # `before_sleep`; `before_log` words its message for a "before" hook - confirm the intent.
    before_sleep=before_log(log, logging.WARNING),
    reraise=True,
)
def request(self, *args, **kwargs):
    """
    Send a request through ``httpx.Client.request`` with a retry layer on top.

    All positional and keyword arguments are forwarded unchanged to the parent implementation.
    Retry behaviour is configured by the ``retry`` decorator: up to ``API_RETRIES`` attempts,
    with ``_default_wait`` between attempts for server errors, network errors and timeouts.
    """
    response = super().request(*args, **kwargs)
    return response

# We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
# methods on one object prefixed with the object type (`.task_instances.update` rather than
# `task_instance_update` etc.)
Expand Down
124 changes: 94 additions & 30 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
from unittest import mock

import httpx
import pytest
Expand All @@ -30,18 +31,28 @@
from airflow.utils.state import TerminalTIState


class TestClient:
def test_error_parsing(self):
def handle_request(request: httpx.Request) -> httpx.Response:
"""
A transport handle that always returns errors
"""
def make_client(transport: httpx.MockTransport) -> Client:
"""Get a client with a custom transport"""
return Client(base_url="test://server", token="", transport=transport)

return httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})

client = Client(
base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
)
def make_client_w_responses(responses: list[httpx.Response]) -> Client:
    """Build a dry-run client whose mock transport replies with ``responses`` in order."""

    def next_response(request: httpx.Request) -> httpx.Response:
        # Consume the caller's list in place so tests can assert on how many replies are left.
        return responses.pop(0)

    transport = httpx.MockTransport(next_response)
    return Client(base_url=None, dry_run=True, token="", mounts={"'http://": transport})


class TestClient:
def test_error_parsing(self):
responses = [
httpx.Response(422, json={"detail": [{"loc": ["#0"], "msg": "err", "type": "required"}]})
]
client = make_client_w_responses(responses)

with pytest.raises(ServerResponseError) as err:
client.get("http://error")
Expand All @@ -53,39 +64,92 @@ def handle_request(request: httpx.Request) -> httpx.Response:
]

def test_error_parsing_plain_text(self):
def handle_request(request: httpx.Request) -> httpx.Response:
"""
A transport handle that always returns errors
"""

return httpx.Response(422, content=b"Internal Server Error")

client = Client(
base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
)
responses = [httpx.Response(422, content=b"Internal Server Error")]
client = make_client_w_responses(responses)

with pytest.raises(httpx.HTTPStatusError) as err:
client.get("http://error")
assert not isinstance(err.value, ServerResponseError)

def test_error_parsing_other_json(self):
def handle_request(request: httpx.Request) -> httpx.Response:
# Some other json than an error body.
return httpx.Response(404, json={"detail": "Not found"})

client = Client(
base_url=None, dry_run=True, token="", mounts={"'http://": httpx.MockTransport(handle_request)}
)
responses = [httpx.Response(404, json={"detail": "Not found"})]
client = make_client_w_responses(responses)

with pytest.raises(ServerResponseError) as err:
client.get("http://error")
assert err.value.args == ("Not found",)
assert err.value.detail is None

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_unrecoverable_error(self, mock_sleep):
responses: list[httpx.Response] = [
*[httpx.Response(500, text="Internal Server Error")] * 11,
httpx.Response(200, json={"detail": "Recovered from error - but will fail before"}),
httpx.Response(400, json={"detail": "Should not get here"}),
]
client = make_client_w_responses(responses)

jscheffl marked this conversation as resolved.
Show resolved Hide resolved
def make_client(transport: httpx.MockTransport) -> Client:
"""Get a client with a custom transport"""
return Client(base_url="test://server", token="", transport=transport)
with pytest.raises(httpx.HTTPStatusError) as err:
client.get("http://error")
assert not isinstance(err.value, ServerResponseError)
assert len(responses) == 3
assert mock_sleep.call_count == 9

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_recovered(self, mock_sleep):
    """Three server errors followed by a success: the client retries and returns the 200."""
    responses: list[httpx.Response] = [
        # Build a separate Response per simulated reply; ``[resp] * 3`` would hand the very same
        # mutable object back to the client three times.
        *(httpx.Response(500, text="Internal Server Error") for _ in range(3)),
        httpx.Response(200, json={"detail": "Recovered from error"}),
        httpx.Response(400, json={"detail": "Should not get here"}),
    ]
    client = make_client_w_responses(responses)

    response = client.get("http://error")

    assert response.status_code == 200
    # Only the trailing 400 must be left: three errors and the 200 were consumed.
    assert len(responses) == 1
    # One sleep per failed attempt.
    assert mock_sleep.call_count == 3

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_overload(self, mock_sleep):
    """A 429 carrying ``Retry-After`` is retried once, sleeping for the advertised delay."""
    busy = httpx.Response(
        429,
        text="I am really busy atm, please back-off",
        headers={"Retry-After": "37"},
    )
    recovered = httpx.Response(200, json={"detail": "Recovered from error"})
    unreachable = httpx.Response(400, json={"detail": "Should not get here"})
    responses: list[httpx.Response] = [busy, recovered, unreachable]
    client = make_client_w_responses(responses)

    result = client.get("http://error")

    assert result.status_code == 200
    # The 429 and the 200 were consumed, the 400 is never requested.
    assert len(responses) == 1
    assert mock_sleep.call_count == 1
    sleep_args, _sleep_kwargs = mock_sleep.call_args
    assert sleep_args[0] == 37

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_non_retry_error(self, mock_sleep):
    """A 422 is raised as ``ServerResponseError`` right away, without sleeping or retrying."""
    bad_request = httpx.Response(422, json={"detail": "Somehow this is a bad request"})
    unreachable = httpx.Response(400, json={"detail": "Should not get here"})
    responses: list[httpx.Response] = [bad_request, unreachable]
    client = make_client_w_responses(responses)

    with pytest.raises(ServerResponseError) as err:
        client.get("http://error")

    # Only the first reply was consumed and no back-off sleep happened.
    assert len(responses) == 1
    assert mock_sleep.call_count == 0
    assert err.value.args == ("Somehow this is a bad request",)

@mock.patch("time.sleep", return_value=None)
def test_retry_handling_ok(self, mock_sleep):
    """A first-try success is returned as is: no retry and no sleep."""
    success = httpx.Response(200, json={"detail": "Recovered from error"})
    unreachable = httpx.Response(400, json={"detail": "Should not get here"})
    responses: list[httpx.Response] = [success, unreachable]
    client = make_client_w_responses(responses)

    result = client.get("http://error")

    assert result.status_code == 200
    # The second reply must still be queued.
    assert len(responses) == 1
    assert mock_sleep.call_count == 0


class TestTaskInstanceOperations:
Expand Down
2 changes: 2 additions & 0 deletions tests/cli/commands/remote_commands/test_task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ def test_cli_run_no_local_no_raw_runs_executor(self, dag_maker):
mock.patch(
"airflow.executors.executor_loader.ExecutorLoader.get_default_executor"
) as get_default_mock,
mock.patch("airflow.executors.local_executor.SimpleQueue"), # Prevent a task being queued
mock.patch("airflow.executors.local_executor.LocalExecutor.end"),
):
EmptyOperator(task_id="task1")
EmptyOperator(task_id="task2", executor="foo_executor_alias")
Expand Down