diff --git a/core/block_validator_test.go b/core/block_validator_test.go index 8ee925211..aeb25b0a2 100644 --- a/core/block_validator_test.go +++ b/core/block_validator_test.go @@ -51,6 +51,31 @@ func testChainConfig() *ChainConfig { }, }, }, + { + Name: "Diehard", + Block: big.NewInt(5), + Features: []*ForkFeature{ + { + ID: "eip155", + Options: ChainFeatureConfigOptions{ + "chainID": 62, + }, + }, + { // ecip1010 bomb delay + ID: "gastable", + Options: ChainFeatureConfigOptions{ + "type": "eip160", + }, + }, + { // ecip1010 bomb delay + ID: "difficulty", + Options: ChainFeatureConfigOptions{ + "type": "ecip1010", + "length": 2000000, + }, + }, + }, + }, }, } } diff --git a/core/config_test.go b/core/config_test.go index 4acecdcbc..b6b5cf518 100644 --- a/core/config_test.go +++ b/core/config_test.go @@ -292,6 +292,7 @@ func TestChainConfig_GetFeature4_WorkForHighNumbers(t *testing.T) { } func TestChainConfig_GetChainID(t *testing.T) { + // Test default hardcoded configs. if DefaultConfig.GetChainID().Cmp(DefaultChainConfigChainID) != 0 { t.Error("got: %v, want: %v", DefaultConfig.GetChainID(), DefaultTestnetChainConfigChainID) } @@ -299,17 +300,31 @@ func TestChainConfig_GetChainID(t *testing.T) { t.Error("got: %v, want: %v", TestConfig.GetChainID(), DefaultTestnetChainConfigChainID) } - // Test parsing default external config. - p, e := filepath.Abs("../cmd/geth/config/mainnet.json") - if e != nil { - t.Errorf("filepath err: %v", e) + // If no chainID (config is empty) returns 0. + c := &ChainConfig{} + cid := c.GetChainID() + // check is zero + if cid.Cmp(new(big.Int)) != 0 { + t.Errorf("got: %v, want: %v", cid, new(big.Int)) } - extConfig, err := ReadExternalChainConfig(p) - if err != nil { - t.Errorf("could not find file: %v", err) + + // Test parsing default external mainnet config. + cases := map[string]*big.Int{ + "../cmd/geth/config/mainnet.json": DefaultChainConfigChainID, + "../cmd/geth/config/testnet.json": DefaultTestnetChainConfigChainID, } - if extConfig.ChainConfig.GetChainID().Cmp(big.NewInt(61)) != 0 { - t.Error("found 0 chainid for eip155") + for extConfigPath, wantInt := range cases { + p, e := filepath.Abs(extConfigPath) + if e != nil { + t.Errorf("filepath err: %v", e) + } + extConfig, err := ReadExternalChainConfig(p) + if err != nil { + t.Errorf("could not find file: %v", err) + } + if extConfig.ChainConfig.GetChainID().Cmp(wantInt) != 0 { + t.Error("got: %v, want: %v", extConfig.ChainConfig.GetChainID(), wantInt) + } } } diff --git a/core/types/transaction_signing.go b/core/types/transaction_signing.go index 2b79ad5e9..f3c672607 100644 --- a/core/types/transaction_signing.go +++ b/core/types/transaction_signing.go @@ -134,7 +134,14 @@ func NewChainIdSigner(chainId *big.Int) ChainIdSigner { func (s ChainIdSigner) Equal(s2 Signer) bool { other, ok := s2.(ChainIdSigner) - return ok && other.chainId.Cmp(s.chainId) == 0 + if !ok { + return false + } + if other.chainId == nil || s.chainId == nil { + return false + } + + return other.chainId.Cmp(s.chainId) == 0 } func (s ChainIdSigner) SignECDSA(tx *Transaction, prv *ecdsa.PrivateKey) (*Transaction, error) { diff --git a/core/types/transaction_signing_test.go b/core/types/transaction_signing_test.go index eaf1d345d..4a6f92f08 100644 --- a/core/types/transaction_signing_test.go +++ b/core/types/transaction_signing_test.go @@ -135,3 +135,28 @@ func TestCompatibleSign(t *testing.T) { t.Errorf("Incorrect pubkey for Basic Signer:\n%v\n%v", common.ToHex(pub), common.ToHex(pub_tx2)) } } + +func TestChainIdSigner_Equal(t *testing.T) { + + defaultChainID := big.NewInt(61) + + s := NewChainIdSigner(defaultChainID) + if s.chainId == nil || s.chainId.Cmp(new(big.Int)) == 0 || s.chainId.Cmp(big.NewInt(0)) == 0 || s.chainId.Cmp(defaultChainID) != 0 { + t.Errorf("unexpected: %v", s) + } + + s2Invalid0 := NewChainIdSigner(new(big.Int)) + if s.Equal(s2Invalid0) { + t.Errorf("unexpected: s: %v, s2: %v", s, s2Invalid0) + } + + s262 := NewChainIdSigner(big.NewInt(62)) + if s.Equal(s262) { + t.Errorf("unexpected: s: %v, s2: %v", s, s262) + } + + s261 := NewChainIdSigner(defaultChainID) + if !s.Equal(s261) { + t.Errorf("unexpected: s: %v, s2: %v", s, s261) + } +}