From 3ce5f167b689ac83168f6a9ba55881f060f9834f Mon Sep 17 00:00:00 2001 From: Bruce Riley Date: Tue, 10 Sep 2024 13:54:07 -0500 Subject: [PATCH] Node/EVM: Verify EVM chain ID --- node/pkg/watchers/evm/config.go | 4 +- node/pkg/watchers/evm/watcher.go | 57 ++++++++++++++++++-- sdk/evm_chain_ids.go | 93 ++++++++++++++++++++++++++++++++ sdk/evm_chain_ids_test.go | 41 ++++++++++++++ 4 files changed, 188 insertions(+), 7 deletions(-) create mode 100644 sdk/evm_chain_ids.go create mode 100644 sdk/evm_chain_ids_test.go diff --git a/node/pkg/watchers/evm/config.go b/node/pkg/watchers/evm/config.go index 9fe8dfa8b7..92a751a47e 100644 --- a/node/pkg/watchers/evm/config.go +++ b/node/pkg/watchers/evm/config.go @@ -53,9 +53,7 @@ func (wc *WatcherConfig) Create( setWriteC = setC } - var devMode bool = (env == common.UnsafeDevNet) - - watcher := NewEthWatcher(wc.Rpc, eth_common.HexToAddress(wc.Contract), string(wc.NetworkID), wc.ChainID, msgC, setWriteC, obsvReqC, queryReqC, queryResponseC, devMode, wc.CcqBackfillCache) + watcher := NewEthWatcher(wc.Rpc, eth_common.HexToAddress(wc.Contract), string(wc.NetworkID), wc.ChainID, msgC, setWriteC, obsvReqC, queryReqC, queryResponseC, env, wc.CcqBackfillCache) watcher.SetL1Finalizer(wc.l1Finalizer) return watcher, watcher.Run, nil } diff --git a/node/pkg/watchers/evm/watcher.go b/node/pkg/watchers/evm/watcher.go index a6fbf4277a..b8e3ab2515 100644 --- a/node/pkg/watchers/evm/watcher.go +++ b/node/pkg/watchers/evm/watcher.go @@ -5,6 +5,8 @@ import ( "fmt" "math" "math/big" + "strconv" + "strings" "sync" "sync/atomic" "time" @@ -28,6 +30,7 @@ import ( "github.com/certusone/wormhole/node/pkg/query" "github.com/certusone/wormhole/node/pkg/readiness" "github.com/certusone/wormhole/node/pkg/supervisor" + "github.com/wormhole-foundation/wormhole/sdk" "github.com/wormhole-foundation/wormhole/sdk/vaa" ) @@ -121,6 +124,7 @@ type ( // Interface to the chain specific ethereum library. ethConn connectors.Connector + env common.Environment unsafeDevMode bool latestBlockNumber uint64 @@ -163,7 +167,7 @@ func NewEthWatcher( obsvReqC <-chan *gossipv1.ObservationRequest, queryReqC <-chan *query.PerChainQueryInternal, queryResponseC chan<- *query.PerChainQueryResponseInternal, - unsafeDevMode bool, + env common.Environment, ccqBackfillCache bool, ) *Watcher { return &Watcher{ @@ -178,7 +182,8 @@ func NewEthWatcher( queryReqC: queryReqC, queryResponseC: queryResponseC, pending: map[pendingKey]*pendingMessage{}, - unsafeDevMode: unsafeDevMode, + env: env, + unsafeDevMode: (env == common.UnsafeDevNet), ccqConfig: query.GetPerChainConfig(chainID), ccqMaxBlockNumber: big.NewInt(0).SetUint64(math.MaxUint64), ccqBackfillCache: ccqBackfillCache, @@ -211,14 +216,18 @@ func (w *Watcher) Run(parentCtx context.Context) error { ContractAddress: w.contract.Hex(), }) - timeout, cancel := context.WithTimeout(ctx, 15*time.Second) - defer cancel() + if err := w.verifyEvmChainID(ctx, logger); err != nil { + return fmt.Errorf("failed to verify evm chain id: %w", err) + } finalizedPollingSupported, safePollingSupported, err := w.getFinality(ctx) if err != nil { return fmt.Errorf("failed to determine finality: %w", err) } + timeout, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + if finalizedPollingSupported { if safePollingSupported { logger.Info("polling for finalized and safe blocks") @@ -794,6 +803,46 @@ func (w *Watcher) getFinality(ctx context.Context) (bool, bool, error) { return finalized, safe, nil } +// verifyEvmChainID reads the EVM chain ID from the node and verifies that it matches the expected value (making sure we aren't connected to the wrong chain). +func (w *Watcher) verifyEvmChainID(ctx context.Context, logger *zap.Logger) error { + timeout, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + + c, err := rpc.DialContext(timeout, w.url) + if err != nil { + return fmt.Errorf("failed to connect to endpoint: %w", err) + } + + var str string + err = c.CallContext(ctx, &str, "eth_chainId") + if err != nil { + return fmt.Errorf("failed to read evm chain id: %w", err) + } + + evmChainID, err := strconv.ParseUint(strings.TrimPrefix(str, "0x"), 16, 64) + if err != nil { + return fmt.Errorf(`eth_chainId returned an invalid int: "%s"`, str) + } + + logger.Info("queried evm chain id", zap.Uint64("evmChainID", evmChainID)) + + if w.unsafeDevMode { + // In devnet we log the result but don't enforce it. + return nil + } + + expectedEvmChainID, err := sdk.GetEvmChainID(string(w.env), w.chainID) + if err != nil { + return fmt.Errorf("failed to look up evm chain id: %w", err) + } + + if evmChainID != uint64(expectedEvmChainID) { + return fmt.Errorf("evm chain ID miss match, expected %d, received %d", expectedEvmChainID, evmChainID) + } + + return nil +} + // SetL1Finalizer is used to set the layer one finalizer. func (w *Watcher) SetL1Finalizer(l1Finalizer interfaces.L1Finalizer) { w.l1Finalizer = l1Finalizer diff --git a/sdk/evm_chain_ids.go b/sdk/evm_chain_ids.go new file mode 100644 index 0000000000..1c513f9cc1 --- /dev/null +++ b/sdk/evm_chain_ids.go @@ -0,0 +1,93 @@ +package sdk + +import ( + "errors" + "strings" + + "github.com/wormhole-foundation/wormhole/sdk/vaa" +) + +type EvmChainIDs map[vaa.ChainID]int + +var mainnetEvmChainIDs = EvmChainIDs{ + vaa.ChainIDAcala: 787, + vaa.ChainIDArbitrum: 42161, + vaa.ChainIDAurora: 1313161554, + vaa.ChainIDAvalanche: 43114, + vaa.ChainIDBSC: 56, + vaa.ChainIDBase: 8453, + vaa.ChainIDBlast: 81457, + vaa.ChainIDCelo: 42220, + vaa.ChainIDEthereum: 1, + vaa.ChainIDFantom: 250, + vaa.ChainIDGnosis: 100, + vaa.ChainIDKarura: 686, + vaa.ChainIDKlaytn: 8217, + vaa.ChainIDLinea: 0, // TODO: We need this value + vaa.ChainIDMantle: 5000, + vaa.ChainIDMoonbeam: 1284, + vaa.ChainIDOasis: 42262, + vaa.ChainIDOptimism: 10, + vaa.ChainIDPolygon: 137, + vaa.ChainIDRootstock: 30, + vaa.ChainIDScroll: 534352, + vaa.ChainIDSnaxchain: 2192, + vaa.ChainIDXLayer: 196, +} + +var testnetEvmChainIDs = EvmChainIDs{ + vaa.ChainIDAcala: 597, + vaa.ChainIDArbitrum: 421613, + vaa.ChainIDArbitrumSepolia: 421614, + vaa.ChainIDAurora: 1313161555, + vaa.ChainIDAvalanche: 43113, + vaa.ChainIDBSC: 97, + vaa.ChainIDBase: 84531, + vaa.ChainIDBaseSepolia: 84532, + vaa.ChainIDBerachain: 80084, + vaa.ChainIDBlast: 168587773, + vaa.ChainIDCelo: 44787, + vaa.ChainIDEthereum: 17000, // This is actually the value for Holesky, since Goerli obsolete. + vaa.ChainIDFantom: 4002, + vaa.ChainIDGnosis: 77, + vaa.ChainIDHolesky: 17000, + vaa.ChainIDKarura: 596, + vaa.ChainIDKlaytn: 1001, + vaa.ChainIDLinea: 59141, + vaa.ChainIDMantle: 5003, + vaa.ChainIDMoonbeam: 1287, + vaa.ChainIDOasis: 42261, + vaa.ChainIDOptimism: 420, + vaa.ChainIDOptimismSepolia: 11155420, + vaa.ChainIDPolygon: 80001, + vaa.ChainIDPolygonSepolia: 80002, + vaa.ChainIDRootstock: 31, + vaa.ChainIDScroll: 534353, + vaa.ChainIDSeiEVM: 713715, + vaa.ChainIDSepolia: 11155111, + vaa.ChainIDSnaxchain: 13001, + vaa.ChainIDXLayer: 195, +} + +var ErrInvalidEnv = errors.New("invalid environment") +var ErrNotFound = errors.New("not found") + +// GetEvmChainID returns the expected EVM chain ID associated with the given Wormhole chain ID and environment passed it. +func GetEvmChainID(env string, chainID vaa.ChainID) (int, error) { + env = strings.ToLower(env) + if env == "prod" || env == "mainnet" { + return getEvmChainID(mainnetEvmChainIDs, chainID) + } + if env == "test" || env == "testnet" { + return getEvmChainID(testnetEvmChainIDs, chainID) + } + return 0, ErrInvalidEnv +} + +func getEvmChainID(evmChains EvmChainIDs, chainID vaa.ChainID) (int, error) { + id, exists := evmChains[chainID] + if !exists { + return 0, ErrNotFound + } + return id, nil +} diff --git a/sdk/evm_chain_ids_test.go b/sdk/evm_chain_ids_test.go new file mode 100644 index 0000000000..57a917f97f --- /dev/null +++ b/sdk/evm_chain_ids_test.go @@ -0,0 +1,41 @@ +package sdk + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wormhole-foundation/wormhole/sdk/vaa" +) + +func TestGetEvmChainID(t *testing.T) { + type test struct { + env string + input vaa.ChainID + output int + err error + } + + // Note: Don't intend to list every chain here, just enough to verify `GetEvmChainID`. + tests := []test{ + {env: "mainnet", input: vaa.ChainIDUnset, output: 0, err: ErrNotFound}, + {env: "mainnet", input: vaa.ChainIDSepolia, output: 0, err: ErrNotFound}, + {env: "mainnet", input: vaa.ChainIDEthereum, output: 1}, + {env: "mainnet", input: vaa.ChainIDArbitrum, output: 42161}, + {env: "testnet", input: vaa.ChainIDSepolia, output: 11155111}, + {env: "testnet", input: vaa.ChainIDEthereum, output: 17000}, + {env: "junk", input: vaa.ChainIDEthereum, output: 17000, err: ErrInvalidEnv}, + } + + for _, tc := range tests { + t.Run(tc.env+"-"+tc.input.String(), func(t *testing.T) { + evmChainID, err := GetEvmChainID(tc.env, tc.input) + if tc.err != nil { + assert.ErrorIs(t, tc.err, err) + } else { + require.NoError(t, err) + assert.Equal(t, tc.output, evmChainID) + } + }) + } +}