diff --git a/consensus/vbft/chain_store.go b/consensus/vbft/chain_store.go index ae0204b8..0a2818d7 100644 --- a/consensus/vbft/chain_store.go +++ b/consensus/vbft/chain_store.go @@ -20,6 +20,7 @@ package vbft import ( "fmt" + "sync" "github.com/ontio/ontology-eventbus/actor" "github.com/polynetwork/poly/common" @@ -36,6 +37,8 @@ type PendingBlock struct { hasSubmitted bool } type ChainStore struct { + mu *sync.Mutex + db *ledger.Ledger chainedBlockNum uint32 pendingBlocks map[uint32]*PendingBlock @@ -50,6 +53,7 @@ func OpenBlockStore(db *ledger.Ledger, serverPid *actor.PID) (*ChainStore, error pendingBlocks: make(map[uint32]*PendingBlock), pid: serverPid, needSubmitBlock: false, + mu: new(sync.Mutex), } merkleRoot, err := db.GetStateMerkleRoot(chainstore.chainedBlockNum) if err != nil { @@ -74,10 +78,16 @@ func (self *ChainStore) close() { } func (self *ChainStore) GetChainedBlockNum() uint32 { + self.mu.Lock() + defer self.mu.Unlock() + return self.chainedBlockNum } func (self *ChainStore) getExecMerkleRoot(blkNum uint32) (common.Uint256, error) { + self.mu.Lock() + defer self.mu.Unlock() + if blk, present := self.pendingBlocks[blkNum]; blk != nil && present { return blk.execResult.MerkleRoot, nil } @@ -92,6 +102,9 @@ func (self *ChainStore) getExecMerkleRoot(blkNum uint32) (common.Uint256, error) } func (self *ChainStore) getCrossStateRoot(blkNum uint32) (common.Uint256, error) { + self.mu.Lock() + defer self.mu.Unlock() + if blk, present := self.pendingBlocks[blkNum]; blk != nil && present { return blk.execResult.CrossStatesRoot, nil } @@ -104,6 +117,9 @@ func (self *ChainStore) getCrossStateRoot(blkNum uint32) (common.Uint256, error) } func (self *ChainStore) getExecWriteSet(blkNum uint32) *overlaydb.MemDB { + self.mu.Lock() + defer self.mu.Unlock() + if blk, present := self.pendingBlocks[blkNum]; blk != nil && present { return blk.execResult.WriteSet } @@ -111,6 +127,9 @@ func (self *ChainStore) getExecWriteSet(blkNum uint32) *overlaydb.MemDB { } func (self *ChainStore) ReloadFromLedger() { + self.mu.Lock() + defer self.mu.Unlock() + height := self.db.GetCurrentBlockHeight() if height > self.chainedBlockNum { // update chainstore height @@ -128,6 +147,9 @@ func (self *ChainStore) ReloadFromLedger() { } func (self *ChainStore) AddBlock(block *Block) error { + self.mu.Lock() + defer self.mu.Unlock() + if block == nil { return fmt.Errorf("try add nil block") } @@ -162,6 +184,9 @@ func (self *ChainStore) AddBlock(block *Block) error { } func (self *ChainStore) submitBlock(blkNum uint32) error { + self.mu.Lock() + defer self.mu.Unlock() + if blkNum == 0 { return nil } @@ -179,6 +204,9 @@ func (self *ChainStore) submitBlock(blkNum uint32) error { } func (self *ChainStore) getBlock(blockNum uint32) (*Block, error) { + self.mu.Lock() + defer self.mu.Unlock() + if blk, present := self.pendingBlocks[blockNum]; present { return blk.block, nil }