diff --git a/sw_utils/consensus.py b/sw_utils/consensus.py index 91d2059..8c9d3d1 100644 --- a/sw_utils/consensus.py +++ b/sw_utils/consensus.py @@ -1,6 +1,6 @@ import logging from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any import aiohttp from eth_typing import URI, HexStr @@ -9,6 +9,7 @@ from web3.beacon.api_endpoints import GET_VOLUNTARY_EXITS from sw_utils.common import urljoin +from sw_utils.decorators import retry_aiohttp_errors from sw_utils.exceptions import AiohttpRecoveredErrors logger = logging.getLogger(__name__) @@ -44,6 +45,9 @@ class ValidatorStatus(Enum): ValidatorStatus.WITHDRAWAL_DONE, ] +if TYPE_CHECKING: + from tenacity import RetryCallState + class ExtendedAsyncBeacon(AsyncBeacon): """ @@ -58,11 +62,12 @@ def __init__( base_urls: list[str], timeout: int = 60, session: aiohttp.ClientSession = None, + retry_timeout: int = 0, ) -> None: self.base_urls = base_urls self.timeout = timeout self.session = session - + self.retry_timeout = retry_timeout super().__init__('') # hack origin base_url param async def get_validators_by_ids(self, validator_ids: list[str], state_id: str = 'head') -> dict: @@ -91,6 +96,27 @@ async def submit_voluntary_exit( logger.error('%s: %s', url, repr(error)) async def _async_make_get_request(self, endpoint_uri: str) -> dict[str, Any]: + if self.retry_timeout: + + def custom_before_log(retry_logger, log_level): + def custom_log_it(retry_state: 'RetryCallState') -> None: + if retry_state.attempt_number <= 1: + return + msg = 'Retrying consensus uri %s(), attempt %s' + args = (endpoint_uri, retry_state.attempt_number) + retry_logger.log(log_level, msg, *args) + + return custom_log_it + + retry_decorator = retry_aiohttp_errors( + self.retry_timeout, + log_func=custom_before_log, + ) + return await retry_decorator(self._async_make_get_request_inner)(endpoint_uri) + + return await self._async_make_get_request_inner(endpoint_uri) + + async def _async_make_get_request_inner(self, endpoint_uri: str) -> dict[str, Any]: for i, url in enumerate(self.base_urls): try: uri = URI(urljoin(url, endpoint_uri)) @@ -114,8 +140,16 @@ async def _make_session_get_request(self, uri): data = await response.json() return data + def set_retry_timeout(self, retry_timeout: int): + self.retry_timeout = retry_timeout + def get_consensus_client( - endpoints: list[str], timeout: int = 60, session: aiohttp.ClientSession = None + endpoints: list[str], + timeout: int = 60, + session: aiohttp.ClientSession = None, + retry_timeout: int = 0, ) -> ExtendedAsyncBeacon: - return ExtendedAsyncBeacon(base_urls=endpoints, timeout=timeout, session=session) + return ExtendedAsyncBeacon( + base_urls=endpoints, timeout=timeout, session=session, retry_timeout=retry_timeout + ) diff --git a/sw_utils/decorators.py b/sw_utils/decorators.py index ac37d86..3725ea2 100644 --- a/sw_utils/decorators.py +++ b/sw_utils/decorators.py @@ -16,7 +16,7 @@ def custom_before_log(logger, log_level): def custom_log_it(retry_state: 'RetryCallState') -> None: if retry_state.attempt_number <= 1: return - msg = 'Retrying %s(), attempt %s' + msg = 'Retrying %s, attempt %s' args = (retry_state.fn.__name__, retry_state.attempt_number) # type: ignore logger.log(log_level, msg, *args) @@ -33,10 +33,10 @@ def can_be_retried_aiohttp_error(e: BaseException) -> bool: return False -def retry_aiohttp_errors(delay: int = 60): +def retry_aiohttp_errors(delay: int = 60, log_func=custom_before_log): return retry( retry=retry_if_exception(can_be_retried_aiohttp_error), wait=wait_exponential(multiplier=1, min=1, max=delay // 2), stop=stop_after_delay(delay), - before=custom_before_log(default_logger, logging.INFO), + before=log_func(default_logger, logging.INFO), ) diff --git a/sw_utils/execution.py b/sw_utils/execution.py index 5cfdd14..f72c6e4 100644 --- a/sw_utils/execution.py +++ b/sw_utils/execution.py @@ -1,6 +1,6 @@ import contextlib import logging -from typing import Any +from typing import TYPE_CHECKING, Any from eth_typing import URI from web3 import AsyncWeb3 @@ -10,11 +10,16 @@ from web3.providers.async_rpc import AsyncHTTPProvider from web3.types import RPCEndpoint, RPCResponse +from sw_utils.decorators import retry_aiohttp_errors from sw_utils.exceptions import AiohttpRecoveredErrors logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from tenacity import RetryCallState + + class ProtocolNotSupported(Exception): """Supported protocols: http, https""" @@ -31,9 +36,11 @@ def __init__( self, endpoint_urls: list[str], request_kwargs: Any | None = None, + retry_timeout: int = 0, ): self._endpoint_urls = endpoint_urls self._providers = [] + self.retry_timeout = retry_timeout if endpoint_urls: self.endpoint_uri = URI(endpoint_urls[0]) @@ -48,9 +55,30 @@ def __init__( super().__init__() async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: + + if self.retry_timeout: + + def custom_before_log(retry_logger, log_level): + def custom_log_it(retry_state: 'RetryCallState') -> None: + if retry_state.attempt_number <= 1: + return + msg = 'Retrying execution method %s, attempt %s' + args = (method, retry_state.attempt_number) + retry_logger.log(log_level, msg, *args) + + return custom_log_it + + retry_decorator = retry_aiohttp_errors( + self.retry_timeout, + log_func=custom_before_log, + ) + return await retry_decorator(self.make_request_inner)(method, params) + + return await self.make_request_inner(method, params) + + async def make_request_inner(self, method: RPCEndpoint, params: Any) -> RPCResponse: if self._locker_provider: return await self._locker_provider.make_request(method, params) - for i, provider in enumerate(self._providers): try: response = await provider.make_request(method, params) @@ -73,10 +101,15 @@ def lock_endpoint(self, endpoint_uri: URI | str): finally: self._locker_provider = None + def set_retry_timeout(self, retry_timeout: int): + self.retry_timeout = retry_timeout + -def get_execution_client(endpoints: list[str], is_poa=False, timeout=60) -> AsyncWeb3: +def get_execution_client( + endpoints: list[str], is_poa=False, timeout=60, retry_timeout=0 +) -> AsyncWeb3: provider = ExtendedAsyncHTTPProvider( - endpoint_urls=endpoints, request_kwargs={'timeout': timeout} + endpoint_urls=endpoints, request_kwargs={'timeout': timeout}, retry_timeout=retry_timeout ) client = AsyncWeb3( provider,