Skip to content

Commit

Permalink
Add support for fallback nodes (#45)
Browse files Browse the repository at this point in the history
* Add support for fallback nodes

* Review fixes

* Review fixes

* Use repr for error logging
  • Loading branch information
cyc60 authored Jul 4, 2023
1 parent b7ae821 commit da1a32c
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 45 deletions.
12 changes: 6 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sw-utils"
version = "0.3.10"
version = "0.3.11"
description = "StakeWise Python utils"
authors = ["StakeWise Labs <[email protected]>"]
license = "GPL-3.0-or-later"
Expand Down
4 changes: 4 additions & 0 deletions sw_utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ def exit_gracefully(self, signum: int, *args, **kwargs) -> None:
def exit_default(self, signum: int, *args, **kwargs) -> None:
# pylint: disable=unused-argument
raise KeyboardInterrupt


def urljoin(*args):
return '/'.join(map(lambda x: str(x).strip('/'), args))
41 changes: 33 additions & 8 deletions sw_utils/consensus.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import logging
from enum import Enum
from typing import Any, Dict
from typing import Any

from eth_typing import URI
from web3._utils.request import async_json_make_get_request
from web3.beacon import AsyncBeacon

from sw_utils.common import urljoin
from sw_utils.exceptions import AiohttpRecoveredErrors

logger = logging.getLogger(__name__)


GET_VALIDATORS = '/eth/v1/beacon/states/{0}/validators{1}'


Expand Down Expand Up @@ -37,18 +44,36 @@ class ValidatorStatus(Enum):


class ExtendedAsyncBeacon(AsyncBeacon):
def __init__(self, base_url: str, timeout: int = 60) -> None:
super().__init__(base_url)
"""
Provider with support for fallback endpoints.
"""

def __init__(
self,
base_urls: list[str],
timeout: int = 60,
) -> None:
self.base_urls = base_urls
self.timeout = timeout
super().__init__('') # hack origin base_url param

async def get_validators_by_ids(self, validator_ids: list[str], state_id: str = 'head') -> dict:
endpoint = GET_VALIDATORS.format(state_id, f"?id={'&id='.join(validator_ids)}")
return await self._async_make_get_request(endpoint)

async def _async_make_get_request(self, endpoint_uri: str) -> Dict[str, Any]:
uri = URI(self.base_url + endpoint_uri)
return await async_json_make_get_request(uri, timeout=self.timeout)
async def _async_make_get_request(self, endpoint_uri: str) -> dict[str, Any]:
for i, url in enumerate(self.base_urls):
try:
uri = URI(urljoin(url, endpoint_uri))
return await async_json_make_get_request(uri, timeout=self.timeout)

except AiohttpRecoveredErrors as error:
if i == len(self.base_urls) - 1:
raise error
logger.error('%s: %s', url, repr(error))

return {}


def get_consensus_client(endpoint: str, timeout: int = 60) -> ExtendedAsyncBeacon:
return ExtendedAsyncBeacon(base_url=endpoint, timeout=timeout)
def get_consensus_client(endpoints: list[str], timeout: int = 60) -> ExtendedAsyncBeacon:
return ExtendedAsyncBeacon(base_urls=endpoints, timeout=timeout)
35 changes: 8 additions & 27 deletions sw_utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,13 @@
import backoff
import requests

logger = logging.getLogger(__name__)


class RecoverableServerError(Exception):
"""
Wrapper around ClientResponseError for HTTP 500 errors.
Only for internal use inside sw-utils library.
Do not raise `RecoverableServerError` in application code.
"""

def __init__(self, origin: requests.HTTPError | aiohttp.ClientResponseError):
self.origin = origin
if isinstance(origin, requests.HTTPError):
self.status_code = origin.response.status_code
self.uri = origin.response.url
elif isinstance(origin, aiohttp.ClientResponseError):
self.status_code = origin.status
self.uri = origin.request_info
from sw_utils.exceptions import (
AiohttpRecoveredErrors,
RecoverableServerError,
RequestsRecoveredErrors,
)

super().__init__()

def __str__(self):
return (
f'RecoverableServerError (status_code: {self.status_code}, '
f'uri: {self.uri}): {self.origin}'
)
logger = logging.getLogger(__name__)


def wrap_aiohttp_500_errors(f):
Expand Down Expand Up @@ -72,7 +53,7 @@ def backoff_aiohttp_errors(max_tries: int | None = None, max_time: int | None =

backoff_decorator = backoff.on_exception(
backoff.expo,
(aiohttp.ClientConnectionError, RecoverableServerError, asyncio.TimeoutError),
AiohttpRecoveredErrors,
max_tries=max_tries,
max_time=max_time,
**kwargs,
Expand Down Expand Up @@ -137,7 +118,7 @@ def backoff_requests_errors(max_tries: int | None = None, max_time: int | None =

backoff_decorator = backoff.on_exception(
backoff.expo,
(requests.ConnectionError, requests.Timeout, RecoverableServerError),
RequestsRecoveredErrors,
max_tries=max_tries,
max_time=max_time,
**kwargs,
Expand Down
44 changes: 44 additions & 0 deletions sw_utils/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import asyncio
import logging

import aiohttp
import requests

logger = logging.getLogger(__name__)


class RecoverableServerError(Exception):
"""
Wrapper around ClientResponseError for HTTP 500 errors.
Only for internal use inside sw-utils library.
Do not raise `RecoverableServerError` in application code.
"""

def __init__(self, origin: requests.HTTPError | aiohttp.ClientResponseError):
self.origin = origin
if isinstance(origin, requests.HTTPError):
self.status_code = origin.response.status_code
self.uri = origin.response.url
elif isinstance(origin, aiohttp.ClientResponseError):
self.status_code = origin.status
self.uri = origin.request_info

super().__init__()

def __str__(self):
return (
f'RecoverableServerError (status_code: {self.status_code}, '
f'uri: {self.uri}): {self.origin}'
)


AiohttpRecoveredErrors = (
aiohttp.ClientConnectionError,
RecoverableServerError,
asyncio.TimeoutError,
)
RequestsRecoveredErrors = (
requests.ConnectionError,
requests.Timeout,
RecoverableServerError,
)
58 changes: 55 additions & 3 deletions sw_utils/execution.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,67 @@
import logging
from typing import Any

from web3 import AsyncHTTPProvider, AsyncWeb3
from eth_typing import URI
from web3 import AsyncWeb3
from web3.eth import AsyncEth
from web3.middleware import async_geth_poa_middleware
from web3.net import AsyncNet
from web3.providers.async_rpc import AsyncHTTPProvider
from web3.types import RPCEndpoint, RPCResponse

from sw_utils.exceptions import AiohttpRecoveredErrors

logger = logging.getLogger(__name__)


def get_execution_client(endpoint: str, is_poa=False, timeout=60) -> AsyncWeb3:
provider = AsyncHTTPProvider(endpoint, request_kwargs={'timeout': timeout})
class ProtocolNotSupported(Exception):
"""Supported protocols: http, https"""


class ExtendedAsyncHTTPProvider(AsyncHTTPProvider):
"""
Provider with support for fallback endpoints.
"""

_providers: list[AsyncHTTPProvider] = []

def __init__(
self,
endpoint_urls: list[str],
request_kwargs: Any | None = None,
):
self._endpoint_urls = endpoint_urls
self._providers = []

if endpoint_urls:
self.endpoint_uri = URI(endpoint_urls[0])

for host_uri in endpoint_urls:
if host_uri.startswith('http'):
self._providers.append(AsyncHTTPProvider(host_uri, request_kwargs))
else:
protocol = host_uri.split('://')[0]
raise ProtocolNotSupported(f'Protocol "{protocol}" is not supported.')

super().__init__()

async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
for i, provider in enumerate(self._providers):
try:
response = await provider.make_request(method, params)
return response
except AiohttpRecoveredErrors as error:
if i == len(self._providers) - 1:
raise error
logger.error('%s: %s', provider.endpoint_uri, repr(error))

return {}


def get_execution_client(endpoints: list[str], is_poa=False, timeout=60) -> AsyncWeb3:
provider = ExtendedAsyncHTTPProvider(
endpoint_urls=endpoints, request_kwargs={'timeout': timeout}
)
client = AsyncWeb3(
provider,
modules={'eth': (AsyncEth,), 'net': AsyncNet},
Expand Down

0 comments on commit da1a32c

Please sign in to comment.