diff --git a/poetry.lock b/poetry.lock index ec97d72..aff46e5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -170,17 +170,6 @@ docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib- tests = ["attrs[tests-no-zope]", "zope-interface"] tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -[[package]] -name = "backoff" -version = "2.2.1" -description = "Function decoration for backoff and retry" -optional = false -python-versions = ">=3.7,<4.0" -files = [ - {file = "backoff-2.2.1-py3-none-any.whl", hash = "sha256:63579f9a0628e06278f7e47b7d7d5b6ce20dc65c5e96a6f3ca99a6adca0396e8"}, - {file = "backoff-2.2.1.tar.gz", hash = "sha256:03f829f5bb1923180821643f8753b0502c3b682293992485b0eef2807afa5cba"}, -] - [[package]] name = "bandit" version = "1.7.5" @@ -2614,4 +2603,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "f0488524b47d828932195829c4d29ef7e6beb9c30ebb2482a3b5afb97899a902" +content-hash = "f03442986f48e4bcc8c2e4defeb1a3ab90312736e141819dbf98c73b82331970" diff --git a/pyproject.toml b/pyproject.toml index 600a4f6..71db2b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sw-utils" -version = "0.3.13" +version = "0.3.14" description = "StakeWise Python utils" authors = ["StakeWise Labs "] license = "GPL-3.0-or-later" @@ -8,7 +8,6 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.10" -backoff = "^2.2.1" milagro-bls-binding = "==1.9.0" py-ecc = "^6.0.0" ipfshttpclient = "^0.8.0a2" diff --git a/sw_utils/decorators.py b/sw_utils/decorators.py index 9fa02f1..ac37d86 100644 --- a/sw_utils/decorators.py +++ b/sw_utils/decorators.py @@ -1,148 +1,42 @@ import asyncio -import functools import logging +import typing import aiohttp -import backoff -import requests +from tenacity import retry, retry_if_exception, stop_after_delay, wait_exponential -from sw_utils.exceptions import ( - AiohttpRecoveredErrors, - RecoverableServerError, - RequestsRecoveredErrors, -) +default_logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +if typing.TYPE_CHECKING: + from tenacity import RetryCallState -def wrap_aiohttp_500_errors(f): - """ - Allows to distinguish between HTTP 400 and HTTP 500 errors. - Both are represented by `aiohttp.ClientResponseError`. - """ - @functools.wraps(f) - async def wrapper(*args, **kwargs): - try: - return await f(*args, **kwargs) - except aiohttp.ClientResponseError as e: - if e.status >= 500: - raise RecoverableServerError(e) from e - raise +def custom_before_log(logger, log_level): + def custom_log_it(retry_state: 'RetryCallState') -> None: + if retry_state.attempt_number <= 1: + return + msg = 'Retrying %s(), attempt %s' + args = (retry_state.fn.__name__, retry_state.attempt_number) # type: ignore + logger.log(log_level, msg, *args) - return wrapper + return custom_log_it -def backoff_aiohttp_errors(max_tries: int | None = None, max_time: int | None = None, **kwargs): - """ - Can be used for: - * retrying web3 api calls - * retrying aiohttp calls to services +def can_be_retried_aiohttp_error(e: BaseException) -> bool: + if isinstance(e, (asyncio.TimeoutError, aiohttp.ClientConnectionError)): + return True - DO NOT use `backoff_aiohttp_errors` for handling errors in IpfsFetchClient - or IpfsMultiUploadClient. - Catch `sw_utils/ipfs.py#IpfsException` instead. + if isinstance(e, aiohttp.ClientResponseError) and e.status >= 500: + return True - Retry: - * connection errors - * HTTP 500 errors - Do not retry: - * HTTP 400 errors - * regular Python errors - """ + return False - backoff_decorator = backoff.on_exception( - backoff.expo, - AiohttpRecoveredErrors, - max_tries=max_tries, - max_time=max_time, - **kwargs, - ) - - def decorator(f): - @functools.wraps(f) - async def wrapper(*args, **kwargs): - try: - return await backoff_decorator(wrap_aiohttp_500_errors(f))(*args, **kwargs) - except RecoverableServerError as e: - raise e.origin - - return wrapper - - return decorator - - -def wrap_requests_500_errors(f): - """ - Allows to distinguish between HTTP 400 and HTTP 500 errors. - Both are represented by `requests.HTTPError`. - """ - if asyncio.iscoroutinefunction(f): - - @functools.wraps(f) - async def async_wrapper(*args, **kwargs): - try: - return await f(*args, **kwargs) - except requests.HTTPError as e: - if e.response.status >= 500: - raise RecoverableServerError(e) from e - raise - - return async_wrapper - - @functools.wraps(f) - def wrapper(*args, **kwargs): - try: - return f(*args, **kwargs) - except requests.HTTPError as e: - if e.response.status >= 500: - raise RecoverableServerError(e) from e - raise - - return wrapper - -def backoff_requests_errors(max_tries: int | None = None, max_time: int | None = None, **kwargs): - """ - DO NOT use `backoff_requests_errors` for handling errors in IpfsFetchClient - or IpfsMultiUploadClient. - Catch `sw_utils/ipfs.py#IpfsException` instead. - - Retry: - * connection errors - * HTTP 500 errors - Do not retry: - * HTTP 400 errors - * regular Python errors - """ - - backoff_decorator = backoff.on_exception( - backoff.expo, - RequestsRecoveredErrors, - max_tries=max_tries, - max_time=max_time, - **kwargs, +def retry_aiohttp_errors(delay: int = 60): + return retry( + retry=retry_if_exception(can_be_retried_aiohttp_error), + wait=wait_exponential(multiplier=1, min=1, max=delay // 2), + stop=stop_after_delay(delay), + before=custom_before_log(default_logger, logging.INFO), ) - - def decorator(f): - if asyncio.iscoroutinefunction(f): - - @functools.wraps(f) - async def async_wrapper(*args, **kwargs): - try: - return await backoff_decorator(wrap_requests_500_errors(f))(*args, **kwargs) - except RecoverableServerError as e: - raise e.origin - - return async_wrapper - - @functools.wraps(f) - def wrapper(*args, **kwargs): - try: - return backoff_decorator(wrap_requests_500_errors(f))(*args, **kwargs) - except RecoverableServerError as e: - raise e.origin - - return wrapper - - return decorator diff --git a/sw_utils/tenacity_decorators.py b/sw_utils/tenacity_decorators.py deleted file mode 100644 index ac37d86..0000000 --- a/sw_utils/tenacity_decorators.py +++ /dev/null @@ -1,42 +0,0 @@ -import asyncio -import logging -import typing - -import aiohttp -from tenacity import retry, retry_if_exception, stop_after_delay, wait_exponential - -default_logger = logging.getLogger(__name__) - - -if typing.TYPE_CHECKING: - from tenacity import RetryCallState - - -def custom_before_log(logger, log_level): - def custom_log_it(retry_state: 'RetryCallState') -> None: - if retry_state.attempt_number <= 1: - return - msg = 'Retrying %s(), attempt %s' - args = (retry_state.fn.__name__, retry_state.attempt_number) # type: ignore - logger.log(log_level, msg, *args) - - return custom_log_it - - -def can_be_retried_aiohttp_error(e: BaseException) -> bool: - if isinstance(e, (asyncio.TimeoutError, aiohttp.ClientConnectionError)): - return True - - if isinstance(e, aiohttp.ClientResponseError) and e.status >= 500: - return True - - return False - - -def retry_aiohttp_errors(delay: int = 60): - return retry( - retry=retry_if_exception(can_be_retried_aiohttp_error), - wait=wait_exponential(multiplier=1, min=1, max=delay // 2), - stop=stop_after_delay(delay), - before=custom_before_log(default_logger, logging.INFO), - ) diff --git a/sw_utils/tests/test_decorators.py b/sw_utils/tests/test_decorators.py deleted file mode 100644 index 33ca65c..0000000 --- a/sw_utils/tests/test_decorators.py +++ /dev/null @@ -1,213 +0,0 @@ -import asyncio -from unittest.mock import Mock - -import aiohttp -import pytest -import requests - -from sw_utils.decorators import backoff_aiohttp_errors, backoff_requests_errors - - -class TestBackoffAiohttpErrors: - async def test_bad_request_http_error(self): - call_count = 0 - - @backoff_aiohttp_errors(max_tries=2, max_time=2) - async def raise_bad_request_http_error(): - nonlocal call_count - call_count += 1 - - # simulate aiohttp.ClientResponse.raise_for_status - raise aiohttp.ClientResponseError( - Mock(), - (Mock(),), - status=400, - message='', - headers={}, - ) - - with pytest.raises(aiohttp.ClientResponseError): - await raise_bad_request_http_error() - - assert call_count == 1 - - async def test_500_http_error(self): - call_count = 0 - - @backoff_aiohttp_errors(max_tries=2, max_time=2) - async def raise_500_http_error(): - nonlocal call_count - call_count += 1 - - # simulate aiohttp.ClientResponse.raise_for_status - raise aiohttp.ClientResponseError( - Mock(), - (Mock(),), - status=500, - message='', - headers={}, - ) - - with pytest.raises(aiohttp.ClientResponseError): - await raise_500_http_error() - - assert call_count == 2 - - async def test_recover_500_http_error(self): - call_count = 0 - - @backoff_aiohttp_errors(max_tries=2, max_time=2) - async def recover_500_http_error(): - nonlocal call_count - call_count += 1 - - if call_count == 1: - # simulate aiohttp.ClientResponse.raise_for_status - raise aiohttp.ClientResponseError( - Mock(), - (Mock(),), - status=500, - message='', - headers={}, - ) - - return 'Recovered after 500 error' - - await recover_500_http_error() - - assert call_count == 2 - - async def test_server_timeout_error(self): - call_count = 0 - - @backoff_aiohttp_errors(max_tries=2, max_time=1) - async def raise_timeout_error(): - nonlocal call_count - call_count += 1 - - raise aiohttp.ServerTimeoutError - - with pytest.raises(aiohttp.ServerTimeoutError): - await raise_timeout_error() - - assert call_count == 2 - - async def test_recover_asyncio_timeout_error(self): - call_count = 0 - - @backoff_aiohttp_errors(max_tries=2, max_time=1) - async def recover_asyncio_timeout_error(): - nonlocal call_count - call_count += 1 - - if call_count == 1: - raise asyncio.TimeoutError - - return 'recovered after asyncio.TimeoutError' - - await recover_asyncio_timeout_error() - - assert call_count == 2 - - -class TestBackoffRequestsErrors: - def test_bad_request_http_error(self): - call_count = 0 - - @backoff_requests_errors(max_tries=2, max_time=2) - def raise_bad_request_http_error(): - nonlocal call_count - call_count += 1 - - # simulate requests.Response.raise_for_status() - response_mock = Mock() - response_mock.status = 400 - raise requests.HTTPError('400 client error', response=response_mock) - - with pytest.raises(requests.HTTPError): - raise_bad_request_http_error() - - assert call_count == 1 - - def test_500_http_error(self): - call_count = 0 - - @backoff_requests_errors(max_tries=2, max_time=2) - def raise_500_http_error(): - nonlocal call_count - call_count += 1 - - # simulate requests.Response.raise_for_status() - response_mock = Mock() - response_mock.status = 500 - raise requests.HTTPError('500 server error', response=response_mock) - - with pytest.raises(requests.HTTPError): - raise_500_http_error() - - assert call_count == 2 - - def test_recover_500_http_error(self): - call_count = 0 - - @backoff_requests_errors(max_tries=2, max_time=2) - def recover_500_http_error(): - nonlocal call_count - call_count += 1 - - if call_count == 1: - # simulate requests.Response.raise_for_status() - response_mock = Mock() - response_mock.status = 500 - raise requests.HTTPError('500 server error', response=response_mock) - - return 'Recovered after 500 error' - - recover_500_http_error() - - assert call_count == 2 - - def test_connect_timeout_error(self): - call_count = 0 - - @backoff_requests_errors(max_tries=2, max_time=1) - def raise_timeout_error(): - nonlocal call_count - call_count += 1 - - raise requests.ConnectTimeout - - with pytest.raises(requests.ConnectTimeout): - raise_timeout_error() - - assert call_count == 2 - - def test_read_timeout_error(self): - call_count = 0 - - @backoff_requests_errors(max_tries=2, max_time=1) - def raise_timeout_error(): - nonlocal call_count - call_count += 1 - - raise requests.ReadTimeout - - with pytest.raises(requests.ReadTimeout): - raise_timeout_error() - - assert call_count == 2 - - async def test_async_connect_timeout_error(self): - call_count = 0 - - @backoff_requests_errors(max_tries=2, max_time=1) - async def raise_timeout_error(): - nonlocal call_count - call_count += 1 - - raise requests.ConnectTimeout - - with pytest.raises(requests.ConnectTimeout): - await raise_timeout_error() - - assert call_count == 2