Skip to content

Commit

Permalink
Node/EVM: Verify EVM chain ID
Browse files Browse the repository at this point in the history
  • Loading branch information
bruce-riley committed Sep 10, 2024
1 parent 6b810ac commit 3ce5f16
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 7 deletions.
4 changes: 1 addition & 3 deletions node/pkg/watchers/evm/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ func (wc *WatcherConfig) Create(
setWriteC = setC
}

var devMode bool = (env == common.UnsafeDevNet)

watcher := NewEthWatcher(wc.Rpc, eth_common.HexToAddress(wc.Contract), string(wc.NetworkID), wc.ChainID, msgC, setWriteC, obsvReqC, queryReqC, queryResponseC, devMode, wc.CcqBackfillCache)
watcher := NewEthWatcher(wc.Rpc, eth_common.HexToAddress(wc.Contract), string(wc.NetworkID), wc.ChainID, msgC, setWriteC, obsvReqC, queryReqC, queryResponseC, env, wc.CcqBackfillCache)
watcher.SetL1Finalizer(wc.l1Finalizer)
return watcher, watcher.Run, nil
}
57 changes: 53 additions & 4 deletions node/pkg/watchers/evm/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"fmt"
"math"
"math/big"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
Expand All @@ -28,6 +30,7 @@ import (
"github.com/certusone/wormhole/node/pkg/query"
"github.com/certusone/wormhole/node/pkg/readiness"
"github.com/certusone/wormhole/node/pkg/supervisor"
"github.com/wormhole-foundation/wormhole/sdk"
"github.com/wormhole-foundation/wormhole/sdk/vaa"
)

Expand Down Expand Up @@ -121,6 +124,7 @@ type (

// Interface to the chain specific ethereum library.
ethConn connectors.Connector
env common.Environment
unsafeDevMode bool

latestBlockNumber uint64
Expand Down Expand Up @@ -163,7 +167,7 @@ func NewEthWatcher(
obsvReqC <-chan *gossipv1.ObservationRequest,
queryReqC <-chan *query.PerChainQueryInternal,
queryResponseC chan<- *query.PerChainQueryResponseInternal,
unsafeDevMode bool,
env common.Environment,
ccqBackfillCache bool,
) *Watcher {
return &Watcher{
Expand All @@ -178,7 +182,8 @@ func NewEthWatcher(
queryReqC: queryReqC,
queryResponseC: queryResponseC,
pending: map[pendingKey]*pendingMessage{},
unsafeDevMode: unsafeDevMode,
env: env,
unsafeDevMode: (env == common.UnsafeDevNet),
ccqConfig: query.GetPerChainConfig(chainID),
ccqMaxBlockNumber: big.NewInt(0).SetUint64(math.MaxUint64),
ccqBackfillCache: ccqBackfillCache,
Expand Down Expand Up @@ -211,14 +216,18 @@ func (w *Watcher) Run(parentCtx context.Context) error {
ContractAddress: w.contract.Hex(),
})

timeout, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
if err := w.verifyEvmChainID(ctx, logger); err != nil {
return fmt.Errorf("failed to verify evm chain id: %w", err)
}

finalizedPollingSupported, safePollingSupported, err := w.getFinality(ctx)
if err != nil {
return fmt.Errorf("failed to determine finality: %w", err)
}

timeout, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()

if finalizedPollingSupported {
if safePollingSupported {
logger.Info("polling for finalized and safe blocks")
Expand Down Expand Up @@ -794,6 +803,46 @@ func (w *Watcher) getFinality(ctx context.Context) (bool, bool, error) {
return finalized, safe, nil
}

// verifyEvmChainID reads the EVM chain ID from the node and verifies that it matches the expected value (making sure we aren't connected to the wrong chain).
func (w *Watcher) verifyEvmChainID(ctx context.Context, logger *zap.Logger) error {
timeout, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()

c, err := rpc.DialContext(timeout, w.url)
if err != nil {
return fmt.Errorf("failed to connect to endpoint: %w", err)
}

var str string
err = c.CallContext(ctx, &str, "eth_chainId")
if err != nil {
return fmt.Errorf("failed to read evm chain id: %w", err)
}

evmChainID, err := strconv.ParseUint(strings.TrimPrefix(str, "0x"), 16, 64)
if err != nil {
return fmt.Errorf(`eth_chainId returned an invalid int: "%s"`, str)
}

logger.Info("queried evm chain id", zap.Uint64("evmChainID", evmChainID))

if w.unsafeDevMode {
// In devnet we log the result but don't enforce it.
return nil
}

expectedEvmChainID, err := sdk.GetEvmChainID(string(w.env), w.chainID)
if err != nil {
return fmt.Errorf("failed to look up evm chain id: %w", err)
}

if evmChainID != uint64(expectedEvmChainID) {
return fmt.Errorf("evm chain ID miss match, expected %d, received %d", expectedEvmChainID, evmChainID)
}

return nil
}

// SetL1Finalizer is used to set the layer one finalizer.
func (w *Watcher) SetL1Finalizer(l1Finalizer interfaces.L1Finalizer) {
w.l1Finalizer = l1Finalizer
Expand Down
93 changes: 93 additions & 0 deletions sdk/evm_chain_ids.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package sdk

import (
"errors"
"strings"

"github.com/wormhole-foundation/wormhole/sdk/vaa"
)

type EvmChainIDs map[vaa.ChainID]int

var mainnetEvmChainIDs = EvmChainIDs{
vaa.ChainIDAcala: 787,
vaa.ChainIDArbitrum: 42161,
vaa.ChainIDAurora: 1313161554,
vaa.ChainIDAvalanche: 43114,
vaa.ChainIDBSC: 56,
vaa.ChainIDBase: 8453,
vaa.ChainIDBlast: 81457,
vaa.ChainIDCelo: 42220,
vaa.ChainIDEthereum: 1,
vaa.ChainIDFantom: 250,
vaa.ChainIDGnosis: 100,
vaa.ChainIDKarura: 686,
vaa.ChainIDKlaytn: 8217,
vaa.ChainIDLinea: 0, // TODO: We need this value
vaa.ChainIDMantle: 5000,
vaa.ChainIDMoonbeam: 1284,
vaa.ChainIDOasis: 42262,
vaa.ChainIDOptimism: 10,
vaa.ChainIDPolygon: 137,
vaa.ChainIDRootstock: 30,
vaa.ChainIDScroll: 534352,
vaa.ChainIDSnaxchain: 2192,
vaa.ChainIDXLayer: 196,
}

var testnetEvmChainIDs = EvmChainIDs{
vaa.ChainIDAcala: 597,
vaa.ChainIDArbitrum: 421613,
vaa.ChainIDArbitrumSepolia: 421614,
vaa.ChainIDAurora: 1313161555,
vaa.ChainIDAvalanche: 43113,
vaa.ChainIDBSC: 97,
vaa.ChainIDBase: 84531,
vaa.ChainIDBaseSepolia: 84532,
vaa.ChainIDBerachain: 80084,
vaa.ChainIDBlast: 168587773,
vaa.ChainIDCelo: 44787,
vaa.ChainIDEthereum: 17000, // This is actually the value for Holesky, since Goerli obsolete.
vaa.ChainIDFantom: 4002,
vaa.ChainIDGnosis: 77,
vaa.ChainIDHolesky: 17000,
vaa.ChainIDKarura: 596,
vaa.ChainIDKlaytn: 1001,
vaa.ChainIDLinea: 59141,
vaa.ChainIDMantle: 5003,
vaa.ChainIDMoonbeam: 1287,
vaa.ChainIDOasis: 42261,
vaa.ChainIDOptimism: 420,
vaa.ChainIDOptimismSepolia: 11155420,
vaa.ChainIDPolygon: 80001,
vaa.ChainIDPolygonSepolia: 80002,
vaa.ChainIDRootstock: 31,
vaa.ChainIDScroll: 534353,
vaa.ChainIDSeiEVM: 713715,
vaa.ChainIDSepolia: 11155111,
vaa.ChainIDSnaxchain: 13001,
vaa.ChainIDXLayer: 195,
}

var ErrInvalidEnv = errors.New("invalid environment")
var ErrNotFound = errors.New("not found")

// GetEvmChainID returns the expected EVM chain ID associated with the given Wormhole chain ID and environment passed it.
func GetEvmChainID(env string, chainID vaa.ChainID) (int, error) {
env = strings.ToLower(env)
if env == "prod" || env == "mainnet" {
return getEvmChainID(mainnetEvmChainIDs, chainID)
}
if env == "test" || env == "testnet" {
return getEvmChainID(testnetEvmChainIDs, chainID)
}
return 0, ErrInvalidEnv
}

func getEvmChainID(evmChains EvmChainIDs, chainID vaa.ChainID) (int, error) {
id, exists := evmChains[chainID]
if !exists {
return 0, ErrNotFound
}
return id, nil
}
41 changes: 41 additions & 0 deletions sdk/evm_chain_ids_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sdk

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/wormhole-foundation/wormhole/sdk/vaa"
)

func TestGetEvmChainID(t *testing.T) {
type test struct {
env string
input vaa.ChainID
output int
err error
}

// Note: Don't intend to list every chain here, just enough to verify `GetEvmChainID`.
tests := []test{
{env: "mainnet", input: vaa.ChainIDUnset, output: 0, err: ErrNotFound},
{env: "mainnet", input: vaa.ChainIDSepolia, output: 0, err: ErrNotFound},
{env: "mainnet", input: vaa.ChainIDEthereum, output: 1},
{env: "mainnet", input: vaa.ChainIDArbitrum, output: 42161},
{env: "testnet", input: vaa.ChainIDSepolia, output: 11155111},
{env: "testnet", input: vaa.ChainIDEthereum, output: 17000},
{env: "junk", input: vaa.ChainIDEthereum, output: 17000, err: ErrInvalidEnv},
}

for _, tc := range tests {
t.Run(tc.env+"-"+tc.input.String(), func(t *testing.T) {
evmChainID, err := GetEvmChainID(tc.env, tc.input)
if tc.err != nil {
assert.ErrorIs(t, tc.err, err)
} else {
require.NoError(t, err)
assert.Equal(t, tc.output, evmChainID)
}
})
}
}

0 comments on commit 3ce5f16

Please sign in to comment.