Skip to content

Commit

Permalink
Add black (#40)
Browse files Browse the repository at this point in the history
* Add black

* Change CI - on push only
  • Loading branch information
evgeny-stakewise authored Jun 16, 2023
1 parent e7356b0 commit 06981fc
Show file tree
Hide file tree
Showing 25 changed files with 208 additions and 248 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: CI

on: [pull_request, push]
on: [ push ]

jobs:
pre-commit:
Expand Down
14 changes: 14 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,17 @@ repos:
files: no-files
args: ["lock", "--check"]
always_run: true

- repo: local
hooks:
- id: black
name: black
entry: black
language: system
types: [ python ]
args:
[
"--check",
"--diff",
]
require_serial: true
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ ignored-modules=["milagro_bls_binding"]
max-line-length = 100
select = ["E121"]

[tool.isort]
profile = "black"

[tool.mypy]
ignore_missing_imports = true
python_version = "3.10"
Expand Down
37 changes: 27 additions & 10 deletions sw_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,31 @@
from .common import InterruptHandler
from .consensus import (PENDING_STATUSES, ExtendedAsyncBeacon, ValidatorStatus,
get_consensus_client)
from .consensus import (
PENDING_STATUSES,
ExtendedAsyncBeacon,
ValidatorStatus,
get_consensus_client,
)
from .event_scanner import EventProcessor, EventScanner
from .execution import get_execution_client
from .ipfs import (BaseUploadClient, IpfsException, IpfsFetchClient,
IpfsMultiUploadClient, IpfsUploadClient, PinataUploadClient,
WebStorageClient)
from .ipfs import (
BaseUploadClient,
IpfsException,
IpfsFetchClient,
IpfsMultiUploadClient,
IpfsUploadClient,
PinataUploadClient,
WebStorageClient,
)
from .middlewares import construct_async_sign_and_send_raw_middleware
from .signing import (DepositData, DepositMessage, compute_deposit_data,
compute_deposit_domain, compute_deposit_message,
compute_signing_root, get_eth1_withdrawal_credentials,
get_exit_message_signing_root,
is_valid_deposit_data_signature, is_valid_exit_signature)
from .signing import (
DepositData,
DepositMessage,
compute_deposit_data,
compute_deposit_domain,
compute_deposit_message,
compute_signing_root,
get_eth1_withdrawal_credentials,
get_exit_message_signing_root,
is_valid_deposit_data_signature,
is_valid_exit_signature,
)
10 changes: 2 additions & 8 deletions sw_utils/consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,11 @@ class ValidatorStatus(Enum):


class ExtendedAsyncBeacon(AsyncBeacon):
def __init__(
self,
base_url: str,
timeout: int = 60
) -> None:
def __init__(self, base_url: str, timeout: int = 60) -> None:
super().__init__(base_url)
self.timeout = timeout

async def get_validators_by_ids(
self, validator_ids: list[str], state_id: str = 'head'
) -> dict:
async def get_validators_by_ids(self, validator_ids: list[str], state_id: str = 'head') -> dict:
endpoint = GET_VALIDATORS.format(state_id, f"?id={'&id='.join(validator_ids)}")
return await self._async_make_get_request(endpoint)

Expand Down
32 changes: 18 additions & 14 deletions sw_utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class RecoverableServerError(Exception):
Only for internal use inside sw-utils library.
Do not raise `RecoverableServerError` in application code.
"""

def __init__(self, origin: requests.HTTPError | aiohttp.ClientResponseError):
self.origin = origin
if isinstance(origin, requests.HTTPError):
Expand All @@ -27,15 +28,18 @@ def __init__(self, origin: requests.HTTPError | aiohttp.ClientResponseError):
super().__init__()

def __str__(self):
return (f'RecoverableServerError (status_code: {self.status_code}, '
f'uri: {self.uri}): {self.origin}')
return (
f'RecoverableServerError (status_code: {self.status_code}, '
f'uri: {self.uri}): {self.origin}'
)


def wrap_aiohttp_500_errors(f):
"""
Allows to distinguish between HTTP 400 and HTTP 500 errors.
Both are represented by `aiohttp.ClientResponseError`.
"""

@functools.wraps(f)
async def wrapper(*args, **kwargs):
try:
Expand All @@ -44,14 +48,11 @@ async def wrapper(*args, **kwargs):
if e.status >= 500:
raise RecoverableServerError(e) from e
raise

return wrapper


def backoff_aiohttp_errors(
max_tries: int | None = None,
max_time: int | None = None,
**kwargs
):
def backoff_aiohttp_errors(max_tries: int | None = None, max_time: int | None = None, **kwargs):
"""
Can be used for:
* retrying web3 api calls
Expand All @@ -74,7 +75,7 @@ def backoff_aiohttp_errors(
(aiohttp.ClientConnectionError, RecoverableServerError, asyncio.TimeoutError),
max_tries=max_tries,
max_time=max_time,
**kwargs
**kwargs,
)

def decorator(f):
Expand All @@ -86,6 +87,7 @@ async def wrapper(*args, **kwargs):
raise e.origin

return wrapper

return decorator


Expand All @@ -95,6 +97,7 @@ def wrap_requests_500_errors(f):
Both are represented by `requests.HTTPError`.
"""
if asyncio.iscoroutinefunction(f):

@functools.wraps(f)
async def async_wrapper(*args, **kwargs):
try:
Expand All @@ -103,6 +106,7 @@ async def async_wrapper(*args, **kwargs):
if e.response.status >= 500:
raise RecoverableServerError(e) from e
raise

return async_wrapper

@functools.wraps(f)
Expand All @@ -113,14 +117,11 @@ def wrapper(*args, **kwargs):
if e.response.status >= 500:
raise RecoverableServerError(e) from e
raise

return wrapper


def backoff_requests_errors(
max_tries: int | None = None,
max_time: int | None = None,
**kwargs
):
def backoff_requests_errors(max_tries: int | None = None, max_time: int | None = None, **kwargs):
"""
DO NOT use `backoff_requests_errors` for handling errors in IpfsFetchClient
or IpfsMultiUploadClient.
Expand All @@ -139,17 +140,19 @@ def backoff_requests_errors(
(requests.ConnectionError, requests.Timeout, RecoverableServerError),
max_tries=max_tries,
max_time=max_time,
**kwargs
**kwargs,
)

def decorator(f):
if asyncio.iscoroutinefunction(f):

@functools.wraps(f)
async def async_wrapper(*args, **kwargs):
try:
return await backoff_decorator(wrap_requests_500_errors(f))(*args, **kwargs)
except RecoverableServerError as e:
raise e.origin

return async_wrapper

@functools.wraps(f)
Expand All @@ -158,6 +161,7 @@ def wrapper(*args, **kwargs):
return backoff_decorator(wrap_requests_500_errors(f))(*args, **kwargs)
except RecoverableServerError as e:
raise e.origin

return wrapper

return decorator
1 change: 1 addition & 0 deletions sw_utils/event_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class EventProcessor(ABC):
"""
Processor of the events.
"""

contract: AsyncContract
contract_event: str

Expand Down
18 changes: 7 additions & 11 deletions sw_utils/ipfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,9 @@ async def remove(self, ipfs_hash: str) -> None:

class IpfsFetchClient:
def __init__(
self,
endpoints: list[str],
timeout: int = 60,
self,
endpoints: list[str],
timeout: int = 60,
):
self.endpoints = endpoints
self.timeout = timeout
Expand All @@ -268,15 +268,13 @@ async def fetch_bytes(self, ipfs_hash: str) -> bytes:

async def _http_gateway_fetch_bytes(self, endpoint: str, ipfs_hash: str) -> bytes:
async with ClientSession(timeout=ClientTimeout(self.timeout)) as session:
async with session.get(
f"{endpoint.rstrip('/')}/ipfs/{ipfs_hash}"
) as response:
async with session.get(f"{endpoint.rstrip('/')}/ipfs/{ipfs_hash}") as response:
response.raise_for_status()
return await response.read()

def _ipfs_fetch_bytes(self, endpoint: str, ipfs_hash: str) -> bytes:
with ipfshttpclient.connect(
endpoint,
endpoint,
) as client:
return client.cat(ipfs_hash, timeout=self.timeout)

Expand All @@ -299,15 +297,13 @@ async def fetch_json(self, ipfs_hash: str) -> dict | list:

async def _http_gateway_fetch_json(self, endpoint: str, ipfs_hash: str) -> dict | list:
async with ClientSession(timeout=ClientTimeout(self.timeout)) as session:
async with session.get(
f"{endpoint.rstrip('/')}/ipfs/{ipfs_hash}"
) as response:
async with session.get(f"{endpoint.rstrip('/')}/ipfs/{ipfs_hash}") as response:
response.raise_for_status()
return await response.json()

def _ipfs_fetch_json(self, endpoint: str, ipfs_hash: str) -> dict | list:
with ipfshttpclient.connect(
endpoint,
endpoint,
) as client:
return client.get_json(ipfs_hash, timeout=self.timeout)

Expand Down
4 changes: 1 addition & 3 deletions sw_utils/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from web3 import Web3
from web3._utils.async_transactions import fill_transaction_defaults
from web3.middleware.signing import format_transaction, gen_normalized_accounts
from web3.types import (AsyncMiddleware, Middleware, RPCEndpoint, RPCResponse,
TxParams)
from web3.types import AsyncMiddleware, Middleware, RPCEndpoint, RPCResponse, TxParams

_PrivateKey = Union[LocalAccount, PrivateKey, HexStr, bytes]

Expand Down Expand Up @@ -48,7 +47,6 @@ def construct_async_sign_and_send_raw_middleware(
async def sign_and_send_raw_middleware(
make_request: Callable[[RPCEndpoint, Any], Any], _async_w3: 'Web3'
) -> AsyncMiddleware:

async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
if method != 'eth_sendTransaction':
return await make_request(method, params)
Expand Down
13 changes: 5 additions & 8 deletions sw_utils/signing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,18 @@ def compute_deposit_message(
public_key: bytes, withdrawal_credentials: bytes, amount_gwei: int
) -> DepositMessage:
return DepositMessage(
pubkey=public_key,
withdrawal_credentials=withdrawal_credentials,
amount=amount_gwei
pubkey=public_key, withdrawal_credentials=withdrawal_credentials, amount=amount_gwei
)


def compute_deposit_data(
public_key: bytes, withdrawal_credentials: bytes, amount_gwei: int, signature: bytes
) -> DepositData:
return DepositData(
pubkey=public_key,
withdrawal_credentials=withdrawal_credentials,
amount=amount_gwei,
signature=signature
signature=signature,
)


Expand Down Expand Up @@ -96,7 +95,7 @@ def is_valid_exit_signature(
public_key: BLSPubkey,
signature: BLSSignature,
genesis_validators_root: Bytes32,
fork: ConsensusFork
fork: ConsensusFork,
) -> bool:
"""Checks whether exit signature is valid."""
# pylint: disable=protected-access
Expand All @@ -109,9 +108,7 @@ def is_valid_exit_signature(


def get_exit_message_signing_root(
validator_index: int,
genesis_validators_root: Bytes32,
fork: ConsensusFork
validator_index: int, genesis_validators_root: Bytes32, fork: ConsensusFork
) -> bytes:
"""Signs exit message."""
domain = _compute_exit_domain(genesis_validators_root, fork.version)
Expand Down
26 changes: 22 additions & 4 deletions sw_utils/ssz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
from .cache import SSZCache # noqa: F401
from .exceptions import DeserializationError # noqa: F401
from .exceptions import SerializationError, SSZException
from .sedes import (BaseSedes, BasicSedes, Byte, ByteVector, # noqa: F401
Container, ProperCompositeSedes, Serializable, UInt,
Vector, byte_vector, bytes4, bytes32, bytes48, bytes96,
uint8, uint16, uint32, uint64, uint128, uint256)
from .sedes import ByteVector # noqa: F401
from .sedes import (
BaseSedes,
BasicSedes,
Byte,
Container,
ProperCompositeSedes,
Serializable,
UInt,
Vector,
byte_vector,
bytes4,
bytes32,
bytes48,
bytes96,
uint8,
uint16,
uint32,
uint64,
uint128,
uint256,
)
4 changes: 1 addition & 3 deletions sw_utils/ssz/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def persistent(self) -> TStructure:
...


class ResizableHashableStructureEvolverAPI(
HashableStructureEvolverAPI[TStructure, TElement]
):
class ResizableHashableStructureEvolverAPI(HashableStructureEvolverAPI[TStructure, TElement]):
@abstractmethod
def append(self, element: TElement) -> None:
...
Expand Down
4 changes: 1 addition & 3 deletions sw_utils/ssz/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def __iter__(self) -> Iterator[bytes]:
raise NotImplementedError('By default, DB classes cannot be iterated.')

def __len__(self) -> int:
raise NotImplementedError(
'By default, classes cannot return the total number of keys.'
)
raise NotImplementedError('By default, classes cannot return the total number of keys.')

@property
def cache_size(self) -> int:
Expand Down
Loading

0 comments on commit 06981fc

Please sign in to comment.