Skip to content

Commit

Permalink
Move retries to clients
Browse files Browse the repository at this point in the history
  • Loading branch information
cyc60 committed Jul 20, 2023
1 parent 9026771 commit edb313d
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 11 deletions.
42 changes: 38 additions & 4 deletions sw_utils/consensus.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING, Any

import aiohttp
from eth_typing import URI, HexStr
Expand All @@ -9,6 +9,7 @@
from web3.beacon.api_endpoints import GET_VOLUNTARY_EXITS

from sw_utils.common import urljoin
from sw_utils.decorators import retry_aiohttp_errors
from sw_utils.exceptions import AiohttpRecoveredErrors

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -44,6 +45,9 @@ class ValidatorStatus(Enum):
ValidatorStatus.WITHDRAWAL_DONE,
]

if TYPE_CHECKING:
from tenacity import RetryCallState


class ExtendedAsyncBeacon(AsyncBeacon):
"""
Expand All @@ -58,11 +62,12 @@ def __init__(
base_urls: list[str],
timeout: int = 60,
session: aiohttp.ClientSession = None,
retry_timeout: int = 0,
) -> None:
self.base_urls = base_urls
self.timeout = timeout
self.session = session

self.retry_timeout = retry_timeout
super().__init__('') # hack origin base_url param

async def get_validators_by_ids(self, validator_ids: list[str], state_id: str = 'head') -> dict:
Expand Down Expand Up @@ -91,6 +96,27 @@ async def submit_voluntary_exit(
logger.error('%s: %s', url, repr(error))

async def _async_make_get_request(self, endpoint_uri: str) -> dict[str, Any]:
if self.retry_timeout:

def custom_before_log(retry_logger, log_level):
def custom_log_it(retry_state: 'RetryCallState') -> None:
if retry_state.attempt_number <= 1:
return
msg = 'Retrying consensus uri %s(), attempt %s'
args = (endpoint_uri, retry_state.attempt_number) # type: ignore
retry_logger.log(log_level, msg, *args)

return custom_log_it

retry_decorator = retry_aiohttp_errors(
self.retry_timeout,
log_func=custom_before_log,
)
return await retry_decorator(self._async_make_get_request_inner)(endpoint_uri)

return await self._async_make_get_request_inner(endpoint_uri)

async def _async_make_get_request_inner(self, endpoint_uri: str) -> dict[str, Any]:
for i, url in enumerate(self.base_urls):
try:
uri = URI(urljoin(url, endpoint_uri))
Expand All @@ -114,8 +140,16 @@ async def _make_session_get_request(self, uri):
data = await response.json()
return data

def set_retry_timeout(self, retry_timeout: int):
self.retry_timeout = retry_timeout


def get_consensus_client(
endpoints: list[str], timeout: int = 60, session: aiohttp.ClientSession = None
endpoints: list[str],
timeout: int = 60,
session: aiohttp.ClientSession = None,
retry_timeout: int = 0,
) -> ExtendedAsyncBeacon:
return ExtendedAsyncBeacon(base_urls=endpoints, timeout=timeout, session=session)
return ExtendedAsyncBeacon(
base_urls=endpoints, timeout=timeout, session=session, retry_timeout=retry_timeout
)
6 changes: 3 additions & 3 deletions sw_utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def custom_before_log(logger, log_level):
def custom_log_it(retry_state: 'RetryCallState') -> None:
if retry_state.attempt_number <= 1:
return
msg = 'Retrying %s(), attempt %s'
msg = 'Retrying %s, attempt %s'
args = (retry_state.fn.__name__, retry_state.attempt_number) # type: ignore
logger.log(log_level, msg, *args)

Expand All @@ -33,10 +33,10 @@ def can_be_retried_aiohttp_error(e: BaseException) -> bool:
return False


def retry_aiohttp_errors(delay: int = 60):
def retry_aiohttp_errors(delay: int = 60, log_func=custom_before_log):
return retry(
retry=retry_if_exception(can_be_retried_aiohttp_error),
wait=wait_exponential(multiplier=1, min=1, max=delay // 2),
stop=stop_after_delay(delay),
before=custom_before_log(default_logger, logging.INFO),
before=log_func(default_logger, logging.INFO),
)
40 changes: 36 additions & 4 deletions sw_utils/execution.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import logging
from typing import Any
from typing import TYPE_CHECKING, Any

from eth_typing import URI
from web3 import AsyncWeb3
Expand All @@ -10,11 +10,16 @@
from web3.providers.async_rpc import AsyncHTTPProvider
from web3.types import RPCEndpoint, RPCResponse

from sw_utils.decorators import retry_aiohttp_errors
from sw_utils.exceptions import AiohttpRecoveredErrors

logger = logging.getLogger(__name__)


if TYPE_CHECKING:
from tenacity import RetryCallState


class ProtocolNotSupported(Exception):
"""Supported protocols: http, https"""

Expand All @@ -31,9 +36,11 @@ def __init__(
self,
endpoint_urls: list[str],
request_kwargs: Any | None = None,
retry_timeout: int = 0,
):
self._endpoint_urls = endpoint_urls
self._providers = []
self.retry_timeout = retry_timeout

if endpoint_urls:
self.endpoint_uri = URI(endpoint_urls[0])
Expand All @@ -48,9 +55,29 @@ def __init__(
super().__init__()

async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:

if self.retry_timeout:
def custom_before_log(retry_logger, log_level):
def custom_log_it(retry_state: 'RetryCallState') -> None:
if retry_state.attempt_number <= 1:
return
msg = 'Retrying execution method %s, attempt %s'
args = (method, retry_state.attempt_number) # type: ignore
retry_logger.log(log_level, msg, *args)

return custom_log_it

retry_decorator = retry_aiohttp_errors(
self.retry_timeout,
log_func=custom_before_log,
)
return await retry_decorator(self.make_request_inner)(method, params)

return await self.make_request_inner(method, params)

async def make_request_inner(self, method: RPCEndpoint, params: Any) -> RPCResponse:
if self._locker_provider:
return await self._locker_provider.make_request(method, params)

for i, provider in enumerate(self._providers):
try:
response = await provider.make_request(method, params)
Expand All @@ -73,10 +100,15 @@ def lock_endpoint(self, endpoint_uri: URI | str):
finally:
self._locker_provider = None

def set_retry_timeout(self, retry_timeout: int):
self.retry_timeout = retry_timeout


def get_execution_client(endpoints: list[str], is_poa=False, timeout=60) -> AsyncWeb3:
def get_execution_client(
endpoints: list[str], is_poa=False, timeout=60, retry_timeout=0
) -> AsyncWeb3:
provider = ExtendedAsyncHTTPProvider(
endpoint_urls=endpoints, request_kwargs={'timeout': timeout}
endpoint_urls=endpoints, request_kwargs={'timeout': timeout}, retry_timeout=retry_timeout
)
client = AsyncWeb3(
provider,
Expand Down

0 comments on commit edb313d

Please sign in to comment.