From 8e2426a2a49b6ef5b7cfcd37b50943a744db4137 Mon Sep 17 00:00:00 2001 From: Brian Stafford Date: Wed, 7 Aug 2024 08:20:43 -0500 Subject: [PATCH] end recovery on shutdown --- wallet/mock.go | 10 +++- wallet/wallet.go | 33 ++++++----- wallet/wallet_test.go | 126 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 154 insertions(+), 15 deletions(-) diff --git a/wallet/mock.go b/wallet/mock.go index 43c3d881bb..602ffa5729 100644 --- a/wallet/mock.go +++ b/wallet/mock.go @@ -12,6 +12,9 @@ import ( ) type mockChainClient struct { + getBestBlockHeight int32 + getBlockHashFunc func() (*chainhash.Hash, error) + getBlockHeader *wire.BlockHeader } var _ chain.Interface = (*mockChainClient)(nil) @@ -26,7 +29,7 @@ func (m *mockChainClient) Stop() { func (m *mockChainClient) WaitForShutdown() {} func (m *mockChainClient) GetBestBlock() (*chainhash.Hash, int32, error) { - return nil, 0, nil + return nil, m.getBestBlockHeight, nil } func (m *mockChainClient) GetBlock(*chainhash.Hash) (*wire.MsgBlock, error) { @@ -34,12 +37,15 @@ func (m *mockChainClient) GetBlock(*chainhash.Hash) (*wire.MsgBlock, error) { } func (m *mockChainClient) GetBlockHash(int64) (*chainhash.Hash, error) { + if m.getBlockHashFunc != nil { + return m.getBlockHashFunc() + } return nil, nil } func (m *mockChainClient) GetBlockHeader(*chainhash.Hash) (*wire.BlockHeader, error) { - return nil, nil + return m.getBlockHeader, nil } func (m *mockChainClient) IsCurrent() bool { diff --git a/wallet/wallet.go b/wallet/wallet.go index 96891f0c3e..d3bfe9811b 100644 --- a/wallet/wallet.go +++ b/wallet/wallet.go @@ -280,6 +280,8 @@ func (w *Wallet) quitChan() <-chan struct{} { // Stop signals all wallet goroutines to shutdown. 
func (w *Wallet) Stop() { + <-w.endRecovery() + w.quitMu.Lock() quit := w.quit w.quitMu.Unlock() @@ -1380,6 +1382,23 @@ type ( heldUnlock chan struct{} ) +// endRecovery tells (*Wallet).recovery to stop, if running, and returns a +// channel that will be closed when the recovery routine exits. +func (w *Wallet) endRecovery() <-chan struct{} { + if recoverySyncI := w.recovering.Load(); recoverySyncI != nil { + recoverySync := recoverySyncI.(*recoverySyncer) + + // If recovery is still running, it will end early with an error + // once we set the quit flag. + atomic.StoreUint32(&recoverySync.quit, 1) + + return recoverySync.done + } + c := make(chan struct{}) + close(c) + return c +} + // walletLocker manages the locked/unlocked state of a wallet. func (w *Wallet) walletLocker() { var timeout <-chan time.Time @@ -1472,19 +1491,7 @@ out: // We can't lock the manager if recovery is active because we use // cryptoKeyPriv and cryptoKeyScript in recovery. - if recoverySyncI := w.recovering.Load(); recoverySyncI != nil { - recoverySync := recoverySyncI.(*recoverySyncer) - // If recovery is still running, it will end early with an error - // once we set the quit flag. 
- atomic.StoreUint32(&recoverySync.quit, 1) - - select { - case <-recoverySync.done: - case <-quit: - break out - } - - } + <-w.endRecovery() timeout = nil err := w.Manager.Lock() diff --git a/wallet/wallet_test.go b/wallet/wallet_test.go index 1a06777a6d..dcd95c0133 100644 --- a/wallet/wallet_test.go +++ b/wallet/wallet_test.go @@ -3,12 +3,16 @@ package wallet import ( "encoding/hex" "fmt" + "math" + "strings" "sync" + "sync/atomic" "testing" "time" "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/waddrmgr" "github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/wtxmgr" @@ -359,3 +363,125 @@ func TestDuplicateAddressDerivation(t *testing.T) { require.NoError(t, eg.Wait()) } } + +func TestEndRecovery(t *testing.T) { + // This is an unconventional unit test, but I'm trying to keep things as + // succinct as possible so that this test is readable without having to mock + // up literally everything. + // The unmonitored goroutine we're looking at is pretty deep: + // SynchronizeRPC -> handleChainNotifications -> syncWithChain -> recovery + // The "deadlock" we're addressing isn't actually a deadlock, but the wallet + // will hang on Stop() -> WaitForShutdown() until (*Wallet).recovery gets + // every single block, which could be hours depending on hardware and + // network factors. The WaitGroup is incremented in SynchronizeRPC, and + // WaitForShutdown will not return until handleChainNotifications returns, + // which is blocked by a running (*Wallet).recovery loop. + // It is noted that the conditions for long recovery are difficult to hit + // when using btcwallet with a fresh seed, because it requires an early + // birthday to be set or established. + + w, cleanup := testWallet(t) + + blockHashCalled := make(chan struct{}) + + chainClient := &mockChainClient{ + // Force the loop to iterate about forever. 
+ getBestBlockHeight: math.MaxInt32, + // Get control of when the loop iterates. + getBlockHashFunc: func() (*chainhash.Hash, error) { + blockHashCalled <- struct{}{} + return &chainhash.Hash{}, nil + }, + // Avoid a panic. + getBlockHeader: &wire.BlockHeader{}, + } + + recoveryDone := make(chan struct{}) + go func() { + defer close(recoveryDone) + w.recovery(chainClient, &waddrmgr.BlockStamp{}) + }() + + getBlockHashCalls := func(expCalls int) { + var i int + for { + select { + case <-blockHashCalled: + i++ + case <-time.After(time.Second): + t.Fatal("expected BlockHash to be called") + } + if i == expCalls { + break + } + } + } + + // Recovery is running. + getBlockHashCalls(3) + + // Closing the quit channel, e.g. Stop() without endRecovery, alone will not + // end the recovery loop. + w.quitMu.Lock() + close(w.quit) + w.quitMu.Unlock() + // Continues scanning. + getBlockHashCalls(3) + + // We're done with this one + atomic.StoreUint32(&w.recovering.Load().(*recoverySyncer).quit, 1) + select { + case <-blockHashCalled: + case <-recoveryDone: + } + cleanup() + + // Try again. + w, cleanup = testWallet(t) + defer cleanup() + + // We'll catch the error to make sure we're hitting our desired path. The + // WaitGroup isn't required for the test, but does show how it completes + // shutdown at a higher level. + var err error + w.wg.Add(1) + recoveryDone = make(chan struct{}) + go func() { + defer w.wg.Done() + defer close(recoveryDone) + err = w.recovery(chainClient, &waddrmgr.BlockStamp{}) + }() + + waitedForShutdown := make(chan struct{}) + go func() { + w.WaitForShutdown() + close(waitedForShutdown) + }() + + // Recovery is running. + getBlockHashCalls(3) + + // endRecovery is required to exit the unmonitored goroutine. + end := w.endRecovery() + select { + case <-blockHashCalled: + case <-recoveryDone: + } + <-end + + // testWallet starts a couple of other unrelated goroutines that need to be + // killed, so we still need to close the quit channel. 
+ w.quitMu.Lock() + close(w.quit) + w.quitMu.Unlock() + + select { + case <-waitedForShutdown: + case <-time.After(time.Second): + t.Fatal("WaitForShutdown never returned") + } + + if !strings.EqualFold(err.Error(), "recovery: forced shutdown") { + t.Fatal("wrong error") + } +}