From b457b40f807741e016a6ffbedf352b0b5766c3a6 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 16 Jul 2024 15:35:58 -0700 Subject: [PATCH 01/14] lnwallet: pack paymentDescriptor add/remove heights into Duals The purpose of this commit is to begin the process of packing symmetric fields into the newly introduced Dual structure. The reason for this is that the Dual structure has a handy indexing method where we can supply a ChannelParty and get back a value. This will cut down on the amount of branching code in the main lines of the codebase logic, making it easier to follow what is going on. --- lnwallet/aux_signer.go | 8 +- lnwallet/channel.go | 307 ++++++++++++++++++--------------- lnwallet/channel_test.go | 194 +++++++++++++-------- lnwallet/payment_descriptor.go | 7 +- lnwallet/update_log.go | 9 +- 5 files changed, 301 insertions(+), 224 deletions(-) diff --git a/lnwallet/aux_signer.go b/lnwallet/aux_signer.go index 5d4bc79241..01abe1aae3 100644 --- a/lnwallet/aux_signer.go +++ b/lnwallet/aux_signer.go @@ -109,10 +109,10 @@ func newAuxHtlcDescriptor(p *paymentDescriptor) AuxHtlcDescriptor { ParentIndex: p.ParentIndex, EntryType: p.EntryType, CustomRecords: p.CustomRecords.Copy(), - addCommitHeightRemote: p.addCommitHeightRemote, - addCommitHeightLocal: p.addCommitHeightLocal, - removeCommitHeightRemote: p.removeCommitHeightRemote, - removeCommitHeightLocal: p.removeCommitHeightLocal, + addCommitHeightRemote: p.addCommitHeights.Remote, + addCommitHeightLocal: p.addCommitHeights.Local, + removeCommitHeightRemote: p.removeCommitHeights.Remote, + removeCommitHeightLocal: p.removeCommitHeights.Local, } } diff --git a/lnwallet/channel.go b/lnwallet/channel.go index f3e0769506..4b0ce15191 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1080,17 +1080,19 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, // as we've included this HTLC in our local commitment chain // for the remote party. pd = &paymentDescriptor{ - ChanID: wireMsg.ChanID, - RHash: wireMsg.PaymentHash, - Timeout: wireMsg.Expiry, - Amount: wireMsg.Amount, - EntryType: Add, - HtlcIndex: wireMsg.ID, - LogIndex: logUpdate.LogIndex, - addCommitHeightRemote: commitHeight, - OnionBlob: wireMsg.OnionBlob, - BlindingPoint: wireMsg.BlindingPoint, - CustomRecords: wireMsg.CustomRecords.Copy(), + ChanID: wireMsg.ChanID, + RHash: wireMsg.PaymentHash, + Timeout: wireMsg.Expiry, + Amount: wireMsg.Amount, + EntryType: Add, + HtlcIndex: wireMsg.ID, + LogIndex: logUpdate.LogIndex, + OnionBlob: wireMsg.OnionBlob, + BlindingPoint: wireMsg.BlindingPoint, + CustomRecords: wireMsg.CustomRecords.Copy(), + addCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, } isDustRemote := HtlcIsDust( @@ -1125,14 +1127,16 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) pd = &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - RPreimage: wireMsg.PaymentPreimage, - LogIndex: logUpdate.LogIndex, - ParentIndex: ogHTLC.HtlcIndex, - EntryType: Settle, - removeCommitHeightRemote: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + RPreimage: wireMsg.PaymentPreimage, + LogIndex: logUpdate.LogIndex, + ParentIndex: ogHTLC.HtlcIndex, + EntryType: Settle, + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, } // If we sent a failure for a prior incoming HTLC, then we'll consult @@ -1143,14 +1147,16 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) pd = &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], - removeCommitHeightRemote: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, } // HTLC fails due to malformed onion blobs are treated the exact same @@ -1160,15 +1166,17 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, // TODO(roasbeef): err if nil? pd = &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: MalformedFail, - FailCode: wireMsg.FailureCode, - ShaOnionBlob: wireMsg.ShaOnionBlob, - removeCommitHeightRemote: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: MalformedFail, + FailCode: wireMsg.FailureCode, + ShaOnionBlob: wireMsg.ShaOnionBlob, + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, } // For fee updates we'll create a FeeUpdate type to add to the log. We @@ -1184,9 +1192,13 @@ func (lc *LightningChannel) logUpdateToPayDesc(logUpdate *channeldb.LogUpdate, Amount: lnwire.NewMSatFromSatoshis( btcutil.Amount(wireMsg.FeePerKw), ), - EntryType: FeeUpdate, - addCommitHeightRemote: commitHeight, - removeCommitHeightRemote: commitHeight, + EntryType: FeeUpdate, + addCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, } } @@ -1216,14 +1228,16 @@ func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpda ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - RPreimage: wireMsg.PaymentPreimage, - LogIndex: logUpdate.LogIndex, - ParentIndex: ogHTLC.HtlcIndex, - EntryType: Settle, - removeCommitHeightRemote: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + RPreimage: wireMsg.PaymentPreimage, + LogIndex: logUpdate.LogIndex, + ParentIndex: ogHTLC.HtlcIndex, + EntryType: Settle, + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, }, nil // If we sent a failure for a prior incoming HTLC, then we'll consult the @@ -1233,14 +1247,16 @@ func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpda ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], - removeCommitHeightRemote: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, }, nil // HTLC fails due to malformed onion blocks are treated the exact same @@ -1249,15 +1265,17 @@ func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpda ogHTLC := remoteUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: MalformedFail, - FailCode: wireMsg.FailureCode, - ShaOnionBlob: wireMsg.ShaOnionBlob, - removeCommitHeightRemote: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: MalformedFail, + FailCode: wireMsg.FailureCode, + ShaOnionBlob: wireMsg.ShaOnionBlob, + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, }, nil case *lnwire.UpdateFee: @@ -1267,9 +1285,13 @@ func (lc *LightningChannel) localLogUpdateToPayDesc(logUpdate *channeldb.LogUpda Amount: lnwire.NewMSatFromSatoshis( btcutil.Amount(wireMsg.FeePerKw), ), - EntryType: FeeUpdate, - addCommitHeightRemote: commitHeight, - removeCommitHeightRemote: commitHeight, + EntryType: FeeUpdate, + addCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, + removeCommitHeights: lntypes.Dual[uint64]{ + Remote: commitHeight, + }, }, nil default: @@ -1294,17 +1316,19 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd switch wireMsg := logUpdate.UpdateMsg.(type) { case *lnwire.UpdateAddHTLC: pd := &paymentDescriptor{ - ChanID: wireMsg.ChanID, - RHash: wireMsg.PaymentHash, - Timeout: wireMsg.Expiry, - Amount: wireMsg.Amount, - EntryType: Add, - HtlcIndex: wireMsg.ID, - LogIndex: logUpdate.LogIndex, - addCommitHeightLocal: commitHeight, - OnionBlob: wireMsg.OnionBlob, - BlindingPoint: wireMsg.BlindingPoint, - CustomRecords: wireMsg.CustomRecords.Copy(), + ChanID: wireMsg.ChanID, + RHash: wireMsg.PaymentHash, + Timeout: wireMsg.Expiry, + Amount: wireMsg.Amount, + EntryType: Add, + HtlcIndex: wireMsg.ID, + LogIndex: logUpdate.LogIndex, + OnionBlob: wireMsg.OnionBlob, + BlindingPoint: wireMsg.BlindingPoint, + CustomRecords: wireMsg.CustomRecords.Copy(), + addCommitHeights: lntypes.Dual[uint64]{ + Local: commitHeight, + }, } // We don't need to generate an htlc script yet. This will be @@ -1319,14 +1343,16 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - RPreimage: wireMsg.PaymentPreimage, - LogIndex: logUpdate.LogIndex, - ParentIndex: ogHTLC.HtlcIndex, - EntryType: Settle, - removeCommitHeightLocal: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + RPreimage: wireMsg.PaymentPreimage, + LogIndex: logUpdate.LogIndex, + ParentIndex: ogHTLC.HtlcIndex, + EntryType: Settle, + removeCommitHeights: lntypes.Dual[uint64]{ + Local: commitHeight, + }, }, nil // If we received a failure for a prior outgoing HTLC, then we'll @@ -1336,14 +1362,16 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: Fail, - FailReason: wireMsg.Reason[:], - removeCommitHeightLocal: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: Fail, + FailReason: wireMsg.Reason[:], + removeCommitHeights: lntypes.Dual[uint64]{ + Local: commitHeight, + }, }, nil // HTLC fails due to malformed onion blobs are treated the exact same @@ -1352,15 +1380,17 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd ogHTLC := localUpdateLog.lookupHtlc(wireMsg.ID) return &paymentDescriptor{ - ChanID: wireMsg.ChanID, - Amount: ogHTLC.Amount, - RHash: ogHTLC.RHash, - ParentIndex: ogHTLC.HtlcIndex, - LogIndex: logUpdate.LogIndex, - EntryType: MalformedFail, - FailCode: wireMsg.FailureCode, - ShaOnionBlob: wireMsg.ShaOnionBlob, - removeCommitHeightLocal: commitHeight, + ChanID: wireMsg.ChanID, + Amount: ogHTLC.Amount, + RHash: ogHTLC.RHash, + ParentIndex: ogHTLC.HtlcIndex, + LogIndex: logUpdate.LogIndex, + EntryType: MalformedFail, + FailCode: wireMsg.FailureCode, + ShaOnionBlob: wireMsg.ShaOnionBlob, + removeCommitHeights: lntypes.Dual[uint64]{ + Local: commitHeight, + }, }, nil // For fee updates we'll create a FeeUpdate type to add to the log. We @@ -1376,9 +1406,13 @@ func (lc *LightningChannel) remoteLogUpdateToPayDesc(logUpdate *channeldb.LogUpd Amount: lnwire.NewMSatFromSatoshis( btcutil.Amount(wireMsg.FeePerKw), ), - EntryType: FeeUpdate, - addCommitHeightLocal: commitHeight, - removeCommitHeightLocal: commitHeight, + EntryType: FeeUpdate, + addCommitHeights: lntypes.Dual[uint64]{ + Local: commitHeight, + }, + removeCommitHeights: lntypes.Dual[uint64]{ + Local: commitHeight, + }, }, nil default: @@ -1611,8 +1645,9 @@ func (lc *LightningChannel) restoreStateLogs( // map we created earlier. Note that if this HTLC is not in // incomingRemoteAddHeights, the remote add height will be set // to zero, which indicates that it is not added yet. - htlc.addCommitHeightLocal = localCommitment.height - htlc.addCommitHeightRemote = incomingRemoteAddHeights[htlc.HtlcIndex] + htlc.addCommitHeights.Local = localCommitment.height + htlc.addCommitHeights.Remote = + incomingRemoteAddHeights[htlc.HtlcIndex] // Restore the htlc back to the remote log. lc.updateLogs.Remote.restoreHtlc(&htlc) @@ -1626,8 +1661,9 @@ func (lc *LightningChannel) restoreStateLogs( // As for the incoming HTLCs, we'll use the current remote // commit height as remote add height, and consult the map // created above for the local add height. - htlc.addCommitHeightRemote = remoteCommitment.height - htlc.addCommitHeightLocal = outgoingLocalAddHeights[htlc.HtlcIndex] + htlc.addCommitHeights.Remote = remoteCommitment.height + htlc.addCommitHeights.Local = + outgoingLocalAddHeights[htlc.HtlcIndex] // Restore the htlc back to the local log. lc.updateLogs.Local.restoreHtlc(&htlc) @@ -1722,15 +1758,15 @@ func (lc *LightningChannel) restorePendingRemoteUpdates( switch payDesc.EntryType { case FeeUpdate: if heightSet { - payDesc.addCommitHeightRemote = height - payDesc.removeCommitHeightRemote = height + payDesc.addCommitHeights.Remote = height + payDesc.removeCommitHeights.Remote = height } lc.updateLogs.Remote.restoreUpdate(payDesc) default: if heightSet { - payDesc.removeCommitHeightRemote = height + payDesc.removeCommitHeights.Remote = height } lc.updateLogs.Remote.restoreUpdate(payDesc) @@ -2903,7 +2939,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // number of satoshis we've received within the channel. if mutateState && entry.EntryType == Settle && whoseCommitChain.IsLocal() && - entry.removeCommitHeightLocal == 0 { + entry.removeCommitHeights.Local == 0 { lc.channelState.TotalMSatReceived += entry.Amount } @@ -2943,7 +2979,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // channel. if mutateState && entry.EntryType == Settle && whoseCommitChain.IsLocal() && - entry.removeCommitHeightLocal == 0 { + entry.removeCommitHeights.Local == 0 { lc.channelState.TotalMSatSent += entry.Amount } @@ -3032,14 +3068,14 @@ func (lc *LightningChannel) fetchParent(entry *paymentDescriptor, // The parent add height should never be zero at this point. If // that's the case we probably forgot to send a new commitment. case whoseCommitChain.IsRemote() && - addEntry.addCommitHeightRemote == 0: + addEntry.addCommitHeights.Remote == 0: return nil, fmt.Errorf("parent entry %d for update %d "+ "had zero remote add height", entry.ParentIndex, entry.LogIndex) case whoseCommitChain.IsLocal() && - addEntry.addCommitHeightLocal == 0: + addEntry.addCommitHeights.Local == 0: return nil, fmt.Errorf("parent entry %d for update %d "+ "had zero local add height", entry.ParentIndex, @@ -3063,9 +3099,9 @@ func processAddEntry(htlc *paymentDescriptor, ourBalance, // height w.r.t the local chain. var addHeight *uint64 if whoseCommitChain.IsRemote() { - addHeight = &htlc.addCommitHeightRemote + addHeight = &htlc.addCommitHeights.Remote } else { - addHeight = &htlc.addCommitHeightLocal + addHeight = &htlc.addCommitHeights.Local } if *addHeight != 0 { @@ -3097,9 +3133,9 @@ func processRemoveEntry(htlc *paymentDescriptor, ourBalance, var removeHeight *uint64 if whoseCommitChain.IsRemote() { - removeHeight = &htlc.removeCommitHeightRemote + removeHeight = &htlc.removeCommitHeights.Remote } else { - removeHeight = &htlc.removeCommitHeightLocal + removeHeight = &htlc.removeCommitHeights.Local } // Ignore any removal entries which have already been processed. @@ -3150,11 +3186,11 @@ func processFeeUpdate(feeUpdate *paymentDescriptor, nextHeight uint64, var addHeight *uint64 var removeHeight *uint64 if whoseCommitChain.IsRemote() { - addHeight = &feeUpdate.addCommitHeightRemote - removeHeight = &feeUpdate.removeCommitHeightRemote + addHeight = &feeUpdate.addCommitHeights.Remote + removeHeight = &feeUpdate.removeCommitHeights.Remote } else { - addHeight = &feeUpdate.addCommitHeightLocal - removeHeight = &feeUpdate.removeCommitHeightLocal + addHeight = &feeUpdate.addCommitHeights.Local + removeHeight = &feeUpdate.removeCommitHeights.Local } if *addHeight != 0 { @@ -3419,8 +3455,8 @@ func (lc *LightningChannel) createCommitDiff(newCommit *commitment, // If this entry wasn't committed at the exact height of this // remote commitment, then we'll skip it as it was already // lingering in the log. - if pd.addCommitHeightRemote != newCommit.height && - pd.removeCommitHeightRemote != newCommit.height { + if pd.addCommitHeights.Remote != newCommit.height && + pd.removeCommitHeights.Remote != newCommit.height { continue } @@ -5655,19 +5691,20 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // both of the remote and local heights are non-zero. If either // of these values is zero, it has yet to be committed in both // the local and remote chains. - committedAdd := pd.addCommitHeightRemote > 0 && - pd.addCommitHeightLocal > 0 - committedRmv := pd.removeCommitHeightRemote > 0 && - pd.removeCommitHeightLocal > 0 + committedAdd := pd.addCommitHeights.Remote > 0 && + pd.addCommitHeights.Local > 0 + committedRmv := pd.removeCommitHeights.Remote > 0 && + pd.removeCommitHeights.Local > 0 // Using the height of the remote and local commitments, // preemptively compute whether or not to forward this HTLC for // the case in which this in an Add HTLC, or if this is a // Settle, Fail, or MalformedFail. - shouldFwdAdd := remoteChainTail == pd.addCommitHeightRemote && - localChainTail >= pd.addCommitHeightLocal - shouldFwdRmv := remoteChainTail == pd.removeCommitHeightRemote && - localChainTail >= pd.removeCommitHeightLocal + shouldFwdAdd := remoteChainTail == pd.addCommitHeights.Remote && + localChainTail >= pd.addCommitHeights.Local + shouldFwdRmv := remoteChainTail == + pd.removeCommitHeights.Remote && + localChainTail >= pd.removeCommitHeights.Local // We'll only forward any new HTLC additions iff, it's "freshly // locked in". Meaning that the HTLC was only *just* considered diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 3adb21636d..82b5d55549 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -7614,13 +7614,13 @@ func TestChannelRestoreCommitHeight(t *testing.T) { t.Fatalf("htlc not found in log") } - if pd.addCommitHeightLocal != expLocal { + if pd.addCommitHeights.Local != expLocal { t.Fatalf("expected local add height to be %d, was %d", - expLocal, pd.addCommitHeightLocal) + expLocal, pd.addCommitHeights.Local) } - if pd.addCommitHeightRemote != expRemote { + if pd.addCommitHeights.Remote != expRemote { t.Fatalf("expected remote add height to be %d, was %d", - expRemote, pd.addCommitHeightRemote) + expRemote, pd.addCommitHeights.Remote) } return newChannel } @@ -8402,16 +8402,20 @@ func TestFetchParent(t *testing.T) { remoteEntries: []*paymentDescriptor{ // This entry will be added at log index =0. { - HtlcIndex: 1, - addCommitHeightLocal: 100, - addCommitHeightRemote: 100, + HtlcIndex: 1, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 100, + }, }, // This entry will be added at log index =1, it // is the parent entry we are looking for. { - HtlcIndex: 2, - addCommitHeightLocal: 100, - addCommitHeightRemote: 0, + HtlcIndex: 2, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 0, + }, }, }, whoseCommitChain: lntypes.Remote, @@ -8424,16 +8428,20 @@ func TestFetchParent(t *testing.T) { remoteEntries: []*paymentDescriptor{ // This entry will be added at log index =0. { - HtlcIndex: 1, - addCommitHeightLocal: 100, - addCommitHeightRemote: 100, + HtlcIndex: 1, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 100, + }, }, // This entry will be added at log index =1, it // is the parent entry we are looking for. { - HtlcIndex: 2, - addCommitHeightLocal: 0, - addCommitHeightRemote: 100, + HtlcIndex: 2, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 0, + Remote: 100, + }, }, }, localEntries: nil, @@ -8447,16 +8455,20 @@ func TestFetchParent(t *testing.T) { localEntries: []*paymentDescriptor{ // This entry will be added at log index =0. { - HtlcIndex: 1, - addCommitHeightLocal: 100, - addCommitHeightRemote: 100, + HtlcIndex: 1, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 100, + }, }, // This entry will be added at log index =1, it // is the parent entry we are looking for. { - HtlcIndex: 2, - addCommitHeightLocal: 0, - addCommitHeightRemote: 100, + HtlcIndex: 2, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 0, + Remote: 100, + }, }, }, remoteEntries: nil, @@ -8471,16 +8483,20 @@ func TestFetchParent(t *testing.T) { localEntries: []*paymentDescriptor{ // This entry will be added at log index =0. { - HtlcIndex: 1, - addCommitHeightLocal: 100, - addCommitHeightRemote: 100, + HtlcIndex: 1, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 100, + }, }, // This entry will be added at log index =1, it // is the parent entry we are looking for. { - HtlcIndex: 2, - addCommitHeightLocal: 100, - addCommitHeightRemote: 0, + HtlcIndex: 2, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 0, + }, }, }, remoteEntries: nil, @@ -8495,16 +8511,20 @@ func TestFetchParent(t *testing.T) { remoteEntries: []*paymentDescriptor{ // This entry will be added at log index =0. { - HtlcIndex: 1, - addCommitHeightLocal: 100, - addCommitHeightRemote: 0, + HtlcIndex: 1, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 0, + }, }, // This entry will be added at log index =1, it // is the parent entry we are looking for. { - HtlcIndex: 2, - addCommitHeightLocal: 100, - addCommitHeightRemote: 100, + HtlcIndex: 2, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 100, + }, }, }, whoseCommitChain: lntypes.Remote, @@ -8518,16 +8538,20 @@ func TestFetchParent(t *testing.T) { localEntries: []*paymentDescriptor{ // This entry will be added at log index =0. { - HtlcIndex: 1, - addCommitHeightLocal: 0, - addCommitHeightRemote: 100, + HtlcIndex: 1, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 0, + Remote: 100, + }, }, // This entry will be added at log index =1, it // is the parent entry we are looking for. { - HtlcIndex: 2, - addCommitHeightLocal: 100, - addCommitHeightRemote: 100, + HtlcIndex: 2, + addCommitHeights: lntypes.Dual[uint64]{ + Local: 100, + Remote: 100, + }, }, }, remoteEntries: nil, @@ -8728,10 +8752,12 @@ func TestEvaluateView(t *testing.T) { mutateState: true, ourHtlcs: []*paymentDescriptor{ { - HtlcIndex: 0, - Amount: htlcAddAmount, - EntryType: Add, - addCommitHeightLocal: addHeight, + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeights: lntypes.Dual[uint64]{ + Local: addHeight, + }, }, }, theirHtlcs: []*paymentDescriptor{ @@ -8763,10 +8789,12 @@ func TestEvaluateView(t *testing.T) { mutateState: false, ourHtlcs: []*paymentDescriptor{ { - HtlcIndex: 0, - Amount: htlcAddAmount, - EntryType: Add, - addCommitHeightLocal: addHeight, + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeights: lntypes.Dual[uint64]{ + Local: addHeight, + }, }, }, theirHtlcs: []*paymentDescriptor{ @@ -8813,16 +8841,20 @@ func TestEvaluateView(t *testing.T) { }, theirHtlcs: []*paymentDescriptor{ { - HtlcIndex: 0, - Amount: htlcAddAmount, - EntryType: Add, - addCommitHeightLocal: addHeight, + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeights: lntypes.Dual[uint64]{ + Local: addHeight, + }, }, { - HtlcIndex: 1, - Amount: htlcAddAmount, - EntryType: Add, - addCommitHeightLocal: addHeight, + HtlcIndex: 1, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeights: lntypes.Dual[uint64]{ + Local: addHeight, + }, }, }, expectedFee: feePerKw, @@ -8857,10 +8889,12 @@ func TestEvaluateView(t *testing.T) { }, theirHtlcs: []*paymentDescriptor{ { - HtlcIndex: 0, - Amount: htlcAddAmount, - EntryType: Add, - addCommitHeightLocal: addHeight, + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeights: lntypes.Dual[uint64]{ + Local: addHeight, + }, }, }, expectedFee: feePerKw, @@ -9155,12 +9189,16 @@ func TestProcessFeeUpdate(t *testing.T) { // set in the test. heights := test.startHeights update := &paymentDescriptor{ - Amount: ourFeeUpdateAmt, - addCommitHeightRemote: heights.remoteAdd, - addCommitHeightLocal: heights.localAdd, - removeCommitHeightRemote: heights.remoteRemove, - removeCommitHeightLocal: heights.localRemove, - EntryType: FeeUpdate, + Amount: ourFeeUpdateAmt, + addCommitHeights: lntypes.Dual[uint64]{ + Local: heights.localAdd, + Remote: heights.remoteAdd, + }, + removeCommitHeights: lntypes.Dual[uint64]{ + Local: heights.localRemove, + Remote: heights.remoteRemove, + }, + EntryType: FeeUpdate, } view := &HtlcView{ @@ -9183,10 +9221,10 @@ func TestProcessFeeUpdate(t *testing.T) { func checkHeights(t *testing.T, update *paymentDescriptor, expected heights) { updateHeights := heights{ - localAdd: update.addCommitHeightLocal, - localRemove: update.removeCommitHeightLocal, - remoteAdd: update.addCommitHeightRemote, - remoteRemove: update.removeCommitHeightRemote, + localAdd: update.addCommitHeights.Local, + localRemove: update.removeCommitHeights.Local, + remoteAdd: update.addCommitHeights.Remote, + remoteRemove: update.removeCommitHeights.Remote, } if !reflect.DeepEqual(updateHeights, expected) { @@ -9551,12 +9589,16 @@ func TestProcessAddRemoveEntry(t *testing.T) { heights := test.startHeights update := &paymentDescriptor{ - Amount: updateAmount, - addCommitHeightLocal: heights.localAdd, - addCommitHeightRemote: heights.remoteAdd, - removeCommitHeightLocal: heights.localRemove, - removeCommitHeightRemote: heights.remoteRemove, - EntryType: test.updateType, + Amount: updateAmount, + addCommitHeights: lntypes.Dual[uint64]{ + Local: heights.localAdd, + Remote: heights.remoteAdd, + }, + removeCommitHeights: lntypes.Dual[uint64]{ + Local: heights.localRemove, + Remote: heights.remoteRemove, + }, + EntryType: test.updateType, } var ( diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index 5a51f29ce5..ffa4cc8ce1 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -6,6 +6,7 @@ import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" ) @@ -165,16 +166,14 @@ type paymentDescriptor struct { // which included this HTLC on either the remote or local commitment // chain. This value is used to determine when an HTLC is fully // "locked-in". - addCommitHeightRemote uint64 - addCommitHeightLocal uint64 + addCommitHeights lntypes.Dual[uint64] // removeCommitHeight[Remote|Local] encodes the height of the // commitment which removed the parent pointer of this // paymentDescriptor either due to a timeout or a settle. Once both // these heights are below the tail of both chains, the log entries can // safely be removed. - removeCommitHeightRemote uint64 - removeCommitHeightLocal uint64 + removeCommitHeights lntypes.Dual[uint64] // OnionBlob is an opaque blob which is used to complete multi-hop // routing. diff --git a/lnwallet/update_log.go b/lnwallet/update_log.go index 42e5373a22..2d1f65c9fa 100644 --- a/lnwallet/update_log.go +++ b/lnwallet/update_log.go @@ -153,6 +153,7 @@ func compactLogs(ourLog, theirLog *updateLog, nextA = e.Next() htlc := e.Value + rmvHeights := htlc.removeCommitHeights // We skip Adds, as they will be removed along with the // fail/settles below. @@ -162,9 +163,7 @@ func compactLogs(ourLog, theirLog *updateLog, // If the HTLC hasn't yet been removed from either // chain, the skip it. - if htlc.removeCommitHeightRemote == 0 || - htlc.removeCommitHeightLocal == 0 { - + if rmvHeights.Remote == 0 || rmvHeights.Local == 0 { continue } @@ -172,8 +171,8 @@ func compactLogs(ourLog, theirLog *updateLog, // is at least the height in which the HTLC was // removed, then evict the settle/timeout entry along // with the original add entry. - if remoteChainTail >= htlc.removeCommitHeightRemote && - localChainTail >= htlc.removeCommitHeightLocal { + if remoteChainTail >= rmvHeights.Remote && + localChainTail >= rmvHeights.Local { // Fee updates have no parent htlcs, so we only // remove the update itself. From 71da6b5336b81b29a5d2a9a92e8a3f5cccb180e3 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 16 Jul 2024 15:47:14 -0700 Subject: [PATCH 02/14] lnwallet: consolidate redundant cases using Dual.ForParty --- lnwallet/channel.go | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 4b0ce15191..6b6799227e 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -3067,19 +3067,10 @@ func (lc *LightningChannel) fetchParent(entry *paymentDescriptor, // The parent add height should never be zero at this point. If // that's the case we probably forgot to send a new commitment. - case whoseCommitChain.IsRemote() && - addEntry.addCommitHeights.Remote == 0: - - return nil, fmt.Errorf("parent entry %d for update %d "+ - "had zero remote add height", entry.ParentIndex, - entry.LogIndex) - - case whoseCommitChain.IsLocal() && - addEntry.addCommitHeights.Local == 0: - + case addEntry.addCommitHeights.GetForParty(whoseCommitChain) == 0: return nil, fmt.Errorf("parent entry %d for update %d "+ - "had zero local add height", entry.ParentIndex, - entry.LogIndex) + "had zero %v add height", entry.ParentIndex, + entry.LogIndex, whoseCommitChain) } return addEntry, nil From 49add0f57b3af7b47a800caa143a644499d0f0a1 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 17 Jul 2024 13:09:04 -0700 Subject: [PATCH 03/14] lnwallet: eliminate inner-most layer of evil mutateState nonsense This commit begins the process of moving towards a more principled means of state tracking. We eliminate the mutateState argument from processAddEntry and processRemoveEntry and move the responsibility of mutating said state to the call-sites. The current call-sites of these functions still have their *own* mutateState argument which will be eliminated during upcoming commits. However, following the principle of micro-commits I opted to break these changes up to make review simpler. --- lnwallet/channel.go | 119 +++++++++++++++++++++------------------ lnwallet/channel_test.go | 20 +++++-- 2 files changed, 79 insertions(+), 60 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 6b6799227e..13a421458b 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2953,10 +2953,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, skipThem[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry( - entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, true, mutateState, - ) + rmvHeights := &entry.removeCommitHeights + rmvHeight := rmvHeights.GetForParty(whoseCommitChain) + if rmvHeight == 0 { + processRemoveEntry( + entry, ourBalance, theirBalance, true, + ) + + if mutateState { + rmvHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } + } } for _, entry := range view.TheirUpdates { switch entry.EntryType { @@ -2993,10 +3002,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, skipUs[addEntry.HtlcIndex] = struct{}{} - processRemoveEntry( - entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, false, mutateState, - ) + rmvHeights := &entry.removeCommitHeights + rmvHeight := rmvHeights.GetForParty(whoseCommitChain) + if rmvHeight == 0 { + processRemoveEntry( + entry, ourBalance, theirBalance, false, + ) + + if mutateState { + rmvHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } + } } // Next we take a second pass through all the log entries, skipping any @@ -3008,10 +3026,24 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, continue } - processAddEntry( - entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, false, mutateState, - ) + // Skip the entries that have already had their add commit + // height set for this commit chain. + addHeights := &entry.addCommitHeights + addHeight := addHeights.GetForParty(whoseCommitChain) + if addHeight == 0 { + processAddEntry( + entry, ourBalance, theirBalance, false, + ) + + // If we are mutating the state, then set the add + // height for the appropriate commitment chain to the + // next height. + if mutateState { + addHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } + } newView.OurUpdates = append(newView.OurUpdates, entry) } @@ -3021,10 +3053,24 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, continue } - processAddEntry( - entry, ourBalance, theirBalance, nextHeight, - whoseCommitChain, true, mutateState, - ) + // Skip the entries that have already had their add commit + // height set for this commit chain. + addHeights := &entry.addCommitHeights + addHeight := addHeights.GetForParty(whoseCommitChain) + if addHeight == 0 { + processAddEntry( + entry, ourBalance, theirBalance, true, + ) + + // If we are mutating the state, then set the add + // height for the appropriate commitment chain to the + // next height. + if mutateState { + addHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } + } newView.TheirUpdates = append(newView.TheirUpdates, entry) } @@ -3081,23 +3127,7 @@ func (lc *LightningChannel) fetchParent(entry *paymentDescriptor, // was committed is updated. Keeping track of this inclusion height allows us to // later compact the log once the change is fully committed in both chains. func processAddEntry(htlc *paymentDescriptor, ourBalance, - theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, isIncoming, mutateState bool) { - - // If we're evaluating this entry for the remote chain (to create/view - // a new commitment), then we'll may be updating the height this entry - // was added to the chain. Otherwise, we may be updating the entry's - // height w.r.t the local chain. - var addHeight *uint64 - if whoseCommitChain.IsRemote() { - addHeight = &htlc.addCommitHeights.Remote - } else { - addHeight = &htlc.addCommitHeights.Local - } - - if *addHeight != 0 { - return - } + theirBalance *lnwire.MilliSatoshi, isIncoming bool) { if isIncoming { // If this is a new incoming (un-committed) HTLC, then we need @@ -3109,30 +3139,13 @@ func processAddEntry(htlc *paymentDescriptor, ourBalance, // going HTLC to reflect the pending balance. *ourBalance -= htlc.Amount } - - if mutateState { - *addHeight = nextHeight - } } // processRemoveEntry processes a log entry which settles or times out a // previously added HTLC. If the removal entry has already been processed, it // is skipped. func processRemoveEntry(htlc *paymentDescriptor, ourBalance, - theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, isIncoming, mutateState bool) { - - var removeHeight *uint64 - if whoseCommitChain.IsRemote() { - removeHeight = &htlc.removeCommitHeights.Remote - } else { - removeHeight = &htlc.removeCommitHeights.Local - } - - // Ignore any removal entries which have already been processed. - if *removeHeight != 0 { - return - } + theirBalance *lnwire.MilliSatoshi, isIncoming bool) { switch { // If an incoming HTLC is being settled, then this means that we've @@ -3159,10 +3172,6 @@ func processRemoveEntry(htlc *paymentDescriptor, ourBalance, case !isIncoming && (htlc.EntryType == Fail || htlc.EntryType == MalformedFail): *ourBalance += htlc.Amount } - - if mutateState { - *removeHeight = nextHeight - } } // processFeeUpdate processes a log update that updates the current commitment diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 82b5d55549..997f0c22d5 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9613,15 +9613,25 @@ func TestProcessAddRemoveEntry(t *testing.T) { // update type. Process remove is used for settles, // fails and malformed htlcs. process := processRemoveEntry + heightDual := &update.removeCommitHeights if test.updateType == Add { process = processAddEntry + heightDual = &update.addCommitHeights } - process( - update, &ourBalance, &theirBalance, nextHeight, - test.whoseCommitChain, test.isIncoming, - test.mutateState, - ) + if heightDual.GetForParty(test.whoseCommitChain) == 0 { + process( + update, &ourBalance, &theirBalance, + test.isIncoming, + ) + + if test.mutateState { + heightDual.SetForParty( + test.whoseCommitChain, + nextHeight, + ) + } + } // Check that balances were updated as expected. if ourBalance != test.ourExpectedBalance { From 05347c839263554e609a5275cbe4fc99628f46ec Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 17 Jul 2024 13:35:32 -0700 Subject: [PATCH 04/14] lnwallet: bring processFeeUpdate in line with process[Add|Remove]Entry This commit redoes the API and semantics of processFeeUpdate to make it consistent with the semantics of it's sister functions. This is part of an ongoing series of commits to remove mutateState arguments pervasively from the codebase. As with the previous commit this makes state mutation the caller's responsibility. This temporarily increases code duplication at the call-sites, but this will unlock other refactor opportunities. --- lnwallet/channel.go | 74 +++++++++++++++++++++++----------------- lnwallet/channel_test.go | 25 ++++++++++++-- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 13a421458b..2157afe031 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2927,10 +2927,26 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Process fee updates, updating the current feePerKw. case FeeUpdate: - processFeeUpdate( - entry, nextHeight, whoseCommitChain, - mutateState, newView, + h := entry.addCommitHeights.GetForParty( + whoseCommitChain, ) + + if h == 0 { + processFeeUpdate( + entry, &newView.FeePerKw, nextHeight, + whoseCommitChain, + ) + + if mutateState { + entry.addCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + + entry.removeCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } + } continue } @@ -2975,10 +2991,26 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Process fee updates, updating the current feePerKw. case FeeUpdate: - processFeeUpdate( - entry, nextHeight, whoseCommitChain, - mutateState, newView, + h := entry.addCommitHeights.GetForParty( + whoseCommitChain, ) + + if h == 0 { + processFeeUpdate( + entry, &newView.FeePerKw, nextHeight, + whoseCommitChain, + ) + + if mutateState { + entry.addCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + + entry.removeCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } + } continue } @@ -3176,35 +3208,13 @@ func processRemoveEntry(htlc *paymentDescriptor, ourBalance, // processFeeUpdate processes a log update that updates the current commitment // fee. -func processFeeUpdate(feeUpdate *paymentDescriptor, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, mutateState bool, - view *HtlcView) { - - // Fee updates are applied for all commitments after they are - // sent/received, so we consider them being added and removed at the - // same height. - var addHeight *uint64 - var removeHeight *uint64 - if whoseCommitChain.IsRemote() { - addHeight = &feeUpdate.addCommitHeights.Remote - removeHeight = &feeUpdate.removeCommitHeights.Remote - } else { - addHeight = &feeUpdate.addCommitHeights.Local - removeHeight = &feeUpdate.removeCommitHeights.Local - } - - if *addHeight != 0 { - return - } +func processFeeUpdate(feeUpdate *paymentDescriptor, + feeRef *chainfee.SatPerKWeight, nextHeight uint64, + whoseCommitChain lntypes.ChannelParty) { // If the update wasn't already locked in, update the current fee rate // to reflect this update. - view.FeePerKw = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) - - if mutateState { - *addHeight = nextHeight - *removeHeight = nextHeight - } + *feeRef = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) } // generateRemoteHtlcSigJobs generates a series of HTLC signature jobs for the diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 997f0c22d5..77630d57aa 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9204,11 +9204,30 @@ func TestProcessFeeUpdate(t *testing.T) { view := &HtlcView{ FeePerKw: chainfee.SatPerKWeight(feePerKw), } - processFeeUpdate( - update, nextHeight, test.whoseCommitChain, - test.mutate, view, + + h := update.addCommitHeights.GetForParty( + test.whoseCommitChain, ) + if h == 0 { + processFeeUpdate( + update, &view.FeePerKw, nextHeight, + test.whoseCommitChain, + ) + + if test.mutate { + update.addCommitHeights.SetForParty( + test.whoseCommitChain, + nextHeight, + ) + + update.removeCommitHeights.SetForParty( + test.whoseCommitChain, + nextHeight, + ) + } + } + if view.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", test.expectedFee, feePerKw) From 819239c5c86201180fe8e49ae0c9dd386893d709 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Thu, 18 Jul 2024 17:14:13 -0700 Subject: [PATCH 05/14] lnwallet: inline processUpdateFee and remove the function entirely In this commit we observe that the previous commit reduced the role of this function to a single assignment statement with numerous newly irrelevant parameters. This commit makes the choice of inlining it at the two call-sites within evaluateHTLCView and removing the funciton definition entirely. This also allows us to drop a huge portion of newly irrelevant test code. --- lnwallet/channel.go | 28 ++--- lnwallet/channel_test.go | 217 --------------------------------------- 2 files changed, 11 insertions(+), 234 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 2157afe031..6b6f444313 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2932,9 +2932,11 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) if h == 0 { - processFeeUpdate( - entry, &newView.FeePerKw, nextHeight, - whoseCommitChain, + // If the update wasn't already locked in, + // update the current fee rate to reflect this + // update. + newView.FeePerKw = chainfee.SatPerKWeight( + entry.Amount.ToSatoshis(), ) if mutateState { @@ -2996,11 +2998,14 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) if h == 0 { - processFeeUpdate( - entry, &newView.FeePerKw, nextHeight, - whoseCommitChain, + // If the update wasn't already locked in, + // update the current fee rate to reflect this + // update. + newView.FeePerKw = chainfee.SatPerKWeight( + entry.Amount.ToSatoshis(), ) + if mutateState { entry.addCommitHeights.SetForParty( whoseCommitChain, nextHeight, @@ -3206,17 +3211,6 @@ func processRemoveEntry(htlc *paymentDescriptor, ourBalance, } } -// processFeeUpdate processes a log update that updates the current commitment -// fee. -func processFeeUpdate(feeUpdate *paymentDescriptor, - feeRef *chainfee.SatPerKWeight, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty) { - - // If the update wasn't already locked in, update the current fee rate - // to reflect this update. - *feeRef = chainfee.SatPerKWeight(feeUpdate.Amount.ToSatoshis()) -} - // generateRemoteHtlcSigJobs generates a series of HTLC signature jobs for the // sig pool, along with a channel that if closed, will cancel any jobs after // they have been submitted to the sigPool. This method is to be used when diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 77630d57aa..50eb24663e 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9021,223 +9021,6 @@ type heights struct { remoteRemove uint64 } -// TestProcessFeeUpdate tests the applying of fee updates and mutation of -// local and remote add and remove heights on update messages. -func TestProcessFeeUpdate(t *testing.T) { - const ( - // height is a non-zero height that can be used for htlcs - // heights. - height = 200 - - // nextHeight is a constant that we use for the next height in - // all unit tests. - nextHeight = 400 - - // feePerKw is the fee we start all of our unit tests with. - feePerKw = 1 - - // ourFeeUpdateAmt is an amount that we update fees to expressed - // in msat. - ourFeeUpdateAmt = 20000 - - // ourFeeUpdatePerSat is the fee rate *in satoshis* that we - // expect if we update to ourFeeUpdateAmt. - ourFeeUpdatePerSat = chainfee.SatPerKWeight(20) - ) - - tests := []struct { - name string - startHeights heights - expectedHeights heights - whoseCommitChain lntypes.ChannelParty - mutate bool - expectedFee chainfee.SatPerKWeight - }{ - { - // Looking at local chain, local add is non-zero so - // the update has been applied already; no fee change. - name: "non-zero local height, fee unchanged", - startHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: 0, - remoteRemove: height, - }, - expectedHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: 0, - remoteRemove: height, - }, - whoseCommitChain: lntypes.Local, - mutate: false, - expectedFee: feePerKw, - }, - { - // Looking at local chain, local add is zero so the - // update has not been applied yet; we expect a fee - // update. - name: "zero local height, fee changed", - startHeights: heights{ - localAdd: 0, - localRemove: 0, - remoteAdd: height, - remoteRemove: 0, - }, - expectedHeights: heights{ - localAdd: 0, - localRemove: 0, - remoteAdd: height, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Local, - mutate: false, - expectedFee: ourFeeUpdatePerSat, - }, - { - // Looking at remote chain, the remote add height is - // zero, so the update has not been applied so we expect - // a fee change. - name: "zero remote height, fee changed", - startHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: 0, - remoteRemove: 0, - }, - expectedHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - mutate: false, - expectedFee: ourFeeUpdatePerSat, - }, - { - // Looking at remote chain, the remote add height is - // non-zero, so the update has been applied so we expect - // no fee change. - name: "non-zero remote height, no fee change", - startHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: height, - remoteRemove: 0, - }, - expectedHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: height, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - mutate: false, - expectedFee: feePerKw, - }, - { - // Local add height is non-zero, so the update has - // already been applied; we do not expect fee to - // change or any mutations to be applied. - name: "non-zero local height, mutation not applied", - startHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: 0, - remoteRemove: height, - }, - expectedHeights: heights{ - localAdd: height, - localRemove: 0, - remoteAdd: 0, - remoteRemove: height, - }, - whoseCommitChain: lntypes.Local, - mutate: true, - expectedFee: feePerKw, - }, - { - // Local add is zero and we are looking at our local - // chain, so the update has not been applied yet. We - // expect the local add and remote heights to be - // mutated. - name: "zero height, fee changed, mutation applied", - startHeights: heights{ - localAdd: 0, - localRemove: 0, - remoteAdd: 0, - remoteRemove: 0, - }, - expectedHeights: heights{ - localAdd: nextHeight, - localRemove: nextHeight, - remoteAdd: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Local, - mutate: true, - expectedFee: ourFeeUpdatePerSat, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - // Create a fee update with add and remove heights as - // set in the test. - heights := test.startHeights - update := &paymentDescriptor{ - Amount: ourFeeUpdateAmt, - addCommitHeights: lntypes.Dual[uint64]{ - Local: heights.localAdd, - Remote: heights.remoteAdd, - }, - removeCommitHeights: lntypes.Dual[uint64]{ - Local: heights.localRemove, - Remote: heights.remoteRemove, - }, - EntryType: FeeUpdate, - } - - view := &HtlcView{ - FeePerKw: chainfee.SatPerKWeight(feePerKw), - } - - h := update.addCommitHeights.GetForParty( - test.whoseCommitChain, - ) - - if h == 0 { - processFeeUpdate( - update, &view.FeePerKw, nextHeight, - test.whoseCommitChain, - ) - - if test.mutate { - update.addCommitHeights.SetForParty( - test.whoseCommitChain, - nextHeight, - ) - - update.removeCommitHeights.SetForParty( - test.whoseCommitChain, - nextHeight, - ) - } - } - - if view.FeePerKw != test.expectedFee { - t.Fatalf("expected fee: %v, got: %v", - test.expectedFee, feePerKw) - } - - checkHeights(t, update, test.expectedHeights) - }) - } -} - func checkHeights(t *testing.T, update *paymentDescriptor, expected heights) { updateHeights := heights{ localAdd: update.addCommitHeights.Local, From d82d02831d92f26007f07633ebdab10a24e3710e Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Fri, 19 Jul 2024 16:53:58 -0700 Subject: [PATCH 06/14] lnwallet: remove mutateState from evaluateHTLCView In line with previous commits we are progressively removing the mutateState argument from this call stack for a more principled software design approach. NOTE FOR REVIEWERS: We take a naive approach to updating the tests here and simply take the functionality we are removing from evaluateHTLCView and run it directly after the function in the test suite. It's possible that we should instead remove this from the test suite altogether but I opted to take a more conservative approach with respect to reducing the scope of tests. If you have opinions here, please make them known. --- lntypes/channel_party.go | 2 + lnwallet/channel.go | 183 ++++++++++++++++----------------- lnwallet/channel_test.go | 30 +++++- lnwallet/payment_descriptor.go | 28 +++++ 4 files changed, 146 insertions(+), 97 deletions(-) diff --git a/lntypes/channel_party.go b/lntypes/channel_party.go index 5848becee6..82cbd1045e 100644 --- a/lntypes/channel_party.go +++ b/lntypes/channel_party.go @@ -117,3 +117,5 @@ func MapDual[A, B any](d Dual[A], f func(A) B) Dual[B] { Remote: f(d.Remote), } } + +var BothParties []ChannelParty = []ChannelParty{Local, Remote} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 6b6f444313..8af684c794 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2890,16 +2890,14 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // returned reflects the current state of HTLCs within the remote or local // commitment chain, and the current commitment fee rate. // -// If mutateState is set to true, then the add height of all added HTLCs -// will be set to nextHeight, and the remove height of all removed HTLCs -// will be set to nextHeight. This should therefore only be set to true -// once for each height, and only in concert with signing a new commitment. -// TODO(halseth): return htlcs to mutate instead of mutating inside -// method. +// The return values of this function are as follows: +// 1. The new htlcView reflecting the current channel state. +// 2. A Dual of the updates which have not yet been committed in +// 'whoseCommitChain's commitment chain. func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty, mutateState bool) (*HtlcView, - error) { + whoseCommitChain lntypes.ChannelParty) (*HtlcView, + lntypes.Dual[[]*paymentDescriptor], error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will @@ -2917,8 +2915,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, skipThem := make(map[uint64]struct{}) // First we run through non-add entries in both logs, populating the - // skip sets and mutating the current chain state (crediting balances, - // etc) to reflect the settle/timeout entry encountered. + // skip sets. for _, entry := range view.OurUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. @@ -2938,53 +2935,31 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, newView.FeePerKw = chainfee.SatPerKWeight( entry.Amount.ToSatoshis(), ) - - if mutateState { - entry.addCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - - entry.removeCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } - continue - } - - // If we're settling an inbound HTLC, and it hasn't been - // processed yet, then increment our state tracking the total - // number of satoshis we've received within the channel. - if mutateState && entry.EntryType == Settle && - whoseCommitChain.IsLocal() && - entry.removeCommitHeights.Local == 0 { - lc.channelState.TotalMSatReceived += entry.Amount + continue } addEntry, err := lc.fetchParent( entry, whoseCommitChain, lntypes.Remote, ) if err != nil { - return nil, err + return nil, lntypes.Dual[[]*paymentDescriptor]{}, err } skipThem[addEntry.HtlcIndex] = struct{}{} - rmvHeights := &entry.removeCommitHeights - rmvHeight := rmvHeights.GetForParty(whoseCommitChain) + rmvHeight := entry.removeCommitHeights.GetForParty( + whoseCommitChain, + ) if rmvHeight == 0 { processRemoveEntry( entry, ourBalance, theirBalance, true, ) - - if mutateState { - rmvHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } } + + // Do the same for our peer's updates. for _, entry := range view.TheirUpdates { switch entry.EntryType { // Skip adds for now. They will be processed below. @@ -3004,53 +2979,27 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, newView.FeePerKw = chainfee.SatPerKWeight( entry.Amount.ToSatoshis(), ) - - - if mutateState { - entry.addCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - - entry.removeCommitHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } - continue - } - // If the remote party is settling one of our outbound HTLC's, - // and it hasn't been processed, yet, the increment our state - // tracking the total number of satoshis we've sent within the - // channel. - if mutateState && entry.EntryType == Settle && - whoseCommitChain.IsLocal() && - entry.removeCommitHeights.Local == 0 { - - lc.channelState.TotalMSatSent += entry.Amount + continue } addEntry, err := lc.fetchParent( entry, whoseCommitChain, lntypes.Local, ) if err != nil { - return nil, err + return nil, lntypes.Dual[[]*paymentDescriptor]{}, err } skipUs[addEntry.HtlcIndex] = struct{}{} - rmvHeights := &entry.removeCommitHeights - rmvHeight := rmvHeights.GetForParty(whoseCommitChain) + rmvHeight := entry.removeCommitHeights.GetForParty( + whoseCommitChain, + ) if rmvHeight == 0 { processRemoveEntry( entry, ourBalance, theirBalance, false, ) - - if mutateState { - rmvHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } } @@ -3065,25 +3014,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Skip the entries that have already had their add commit // height set for this commit chain. - addHeights := &entry.addCommitHeights - addHeight := addHeights.GetForParty(whoseCommitChain) + addHeight := entry.addCommitHeights.GetForParty( + whoseCommitChain, + ) if addHeight == 0 { processAddEntry( entry, ourBalance, theirBalance, false, ) - - // If we are mutating the state, then set the add - // height for the appropriate commitment chain to the - // next height. - if mutateState { - addHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } newView.OurUpdates = append(newView.OurUpdates, entry) } + + // Again, we do the same for our peer's updates. for _, entry := range view.TheirUpdates { isAdd := entry.EntryType == Add if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { @@ -3092,27 +3035,51 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Skip the entries that have already had their add commit // height set for this commit chain. - addHeights := &entry.addCommitHeights - addHeight := addHeights.GetForParty(whoseCommitChain) + addHeight := entry.addCommitHeights.GetForParty( + whoseCommitChain, + ) if addHeight == 0 { processAddEntry( entry, ourBalance, theirBalance, true, ) - - // If we are mutating the state, then set the add - // height for the appropriate commitment chain to the - // next height. - if mutateState { - addHeights.SetForParty( - whoseCommitChain, nextHeight, - ) - } } newView.TheirUpdates = append(newView.TheirUpdates, entry) } - return newView, nil + // Create a function that is capable of identifying whether or not the + // paymentDescriptor has been committed in the commitment chain + // corresponding to whoseCommitmentChain. + isUncommitted := func(update *paymentDescriptor) bool { + switch update.EntryType { + case Add: + return update.addCommitHeights.GetForParty( + whoseCommitChain, + ) == 0 + + case FeeUpdate: + return update.addCommitHeights.GetForParty( + whoseCommitChain, + ) == 0 + + case Settle, Fail, MalformedFail: + return update.removeCommitHeights.GetForParty( + whoseCommitChain, + ) == 0 + + default: + panic("invalid paymentDescriptor EntryType") + } + } + + // Collect all of the updates that haven't had their commit heights set + // for the commitment chain corresponding to whoseCommitmentChain. + uncommittedUpdates := lntypes.Dual[[]*paymentDescriptor]{ + Local: fn.Filter(isUncommitted, view.OurUpdates), + Remote: fn.Filter(isUncommitted, view.TheirUpdates), + } + + return newView, uncommittedUpdates, nil } // fetchParent is a helper that looks up update log parent entries in the @@ -4683,13 +4650,27 @@ func (lc *LightningChannel) computeView(view *HtlcView, // channel constraints to the final commitment state. If any fee // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. - filteredHTLCView, err := lc.evaluateHTLCView( + filteredHTLCView, uncommitted, err := lc.evaluateHTLCView( view, &ourBalance, &theirBalance, nextHeight, whoseCommitChain, - updateState, ) if err != nil { return 0, 0, 0, nil, err } + + if updateState { + for _, party := range lntypes.BothParties { + for _, u := range uncommitted.GetForParty(party) { + u.setCommitHeight(whoseCommitChain, nextHeight) + + if whoseCommitChain == lntypes.Local && + u.EntryType == Settle { + + lc.recordSettlement(party, u.Amount) + } + } + } + } + feePerKw := filteredHTLCView.FeePerKw // Here we override the view's fee-rate if a dry-run fee-rate was @@ -4742,6 +4723,18 @@ func (lc *LightningChannel) computeView(view *HtlcView, return ourBalance, theirBalance, totalCommitWeight, filteredHTLCView, nil } +// recordSettlement updates the lifetime payment flow values in persistent state +// of the LightningChannel, adding amt to the total received by the redeemer. +func (lc *LightningChannel) recordSettlement( + redeemer lntypes.ChannelParty, amt lnwire.MilliSatoshi) { + + if redeemer == lntypes.Local { + lc.channelState.TotalMSatReceived += amt + } else { + lc.channelState.TotalMSatSent += amt + } +} + // genHtlcSigValidationJobs generates a series of signatures verification jobs // meant to verify all the signatures for HTLC's attached to a newly created // commitment state. The jobs generated are fully populated, and can be sent diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 50eb24663e..0372e4d166 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8956,14 +8956,40 @@ func TestEvaluateView(t *testing.T) { ) // Evaluate the htlc view, mutate as test expects. - result, err := lc.evaluateHTLCView( + result, uncommitted, err := lc.evaluateHTLCView( view, &ourBalance, &theirBalance, nextHeight, - test.whoseCommitChain, test.mutateState, + test.whoseCommitChain, ) + if err != nil { t.Fatalf("unexpected error: %v", err) } + // TODO(proofofkeags): This block is here because we + // extracted this code from a previous implementation + // of evaluateHTLCView, due to a reduced scope of + // responsibility of that function. Consider removing + // it from the test altogether. + if test.mutateState { + for _, party := range lntypes.BothParties { + us := uncommitted.GetForParty(party) + for _, u := range us { + u.setCommitHeight( + test.whoseCommitChain, + nextHeight, + ) + if test.whoseCommitChain == + lntypes.Local && + u.EntryType == Settle { + + lc.recordSettlement( + party, u.Amount, + ) + } + } + } + } + if result.FeePerKw != test.expectedFee { t.Fatalf("expected fee: %v, got: %v", test.expectedFee, result.FeePerKw) diff --git a/lnwallet/payment_descriptor.go b/lnwallet/payment_descriptor.go index ffa4cc8ce1..a8edb1e7e6 100644 --- a/lnwallet/payment_descriptor.go +++ b/lnwallet/payment_descriptor.go @@ -283,3 +283,31 @@ func (pd *paymentDescriptor) toLogUpdate() channeldb.LogUpdate { UpdateMsg: msg, } } + +// setCommitHeight updates the appropriate addCommitHeight and/or +// removeCommitHeight for whoseCommitChain and locks it in at nextHeight. +func (pd *paymentDescriptor) setCommitHeight( + whoseCommitChain lntypes.ChannelParty, nextHeight uint64) { + + switch pd.EntryType { + case Add: + pd.addCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + case Settle, Fail, MalformedFail: + pd.removeCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + case FeeUpdate: + // Fee updates are applied for all commitments + // after they are sent/received, so we consider + // them being added and removed at the same + // height. + pd.addCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + pd.removeCommitHeights.SetForParty( + whoseCommitChain, nextHeight, + ) + } +} From b902d0825d39c941c299cc3ef38b9d27212ecbbb Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Mon, 22 Jul 2024 14:12:03 -0700 Subject: [PATCH 07/14] lnwallet: use fn.Set API directly instead of empty struct map. --- lnwallet/channel.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 8af684c794..5e6fbc5b5f 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2911,8 +2911,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // keep track of which entries we need to skip when creating the final // htlc view. We skip an entry whenever we find a settle or a timeout // modifying an entry. - skipUs := make(map[uint64]struct{}) - skipThem := make(map[uint64]struct{}) + skipUs := fn.NewSet[uint64]() + skipThem := fn.NewSet[uint64]() // First we run through non-add entries in both logs, populating the // skip sets. @@ -2947,7 +2947,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, return nil, lntypes.Dual[[]*paymentDescriptor]{}, err } - skipThem[addEntry.HtlcIndex] = struct{}{} + skipThem.Add(addEntry.HtlcIndex) rmvHeight := entry.removeCommitHeights.GetForParty( whoseCommitChain, @@ -2991,7 +2991,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, return nil, lntypes.Dual[[]*paymentDescriptor]{}, err } - skipUs[addEntry.HtlcIndex] = struct{}{} + skipUs.Add(addEntry.HtlcIndex) rmvHeight := entry.removeCommitHeights.GetForParty( whoseCommitChain, @@ -3008,7 +3008,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // added HTLCs. for _, entry := range view.OurUpdates { isAdd := entry.EntryType == Add - if _, ok := skipUs[entry.HtlcIndex]; !isAdd || ok { + if skipUs.Contains(entry.HtlcIndex) || !isAdd { continue } @@ -3029,7 +3029,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Again, we do the same for our peer's updates. for _, entry := range view.TheirUpdates { isAdd := entry.EntryType == Add - if _, ok := skipThem[entry.HtlcIndex]; !isAdd || ok { + if skipThem.Contains(entry.HtlcIndex) || !isAdd { continue } From 1b2cb14254acdd81975cabb7d2b2082eb1c1de3d Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Mon, 22 Jul 2024 14:34:18 -0700 Subject: [PATCH 08/14] lnwallet: change bool isIncoming to new lntypes.ChannelParty This commit removes another raw boolean value and replaces it with a more clear type/name. This will also assist us when we later try and consolidate the logic of evaluateHTLCView into a single coherent computation. --- lnwallet/channel.go | 27 ++++++++++++++++----------- lnwallet/channel_test.go | 32 ++++++++++++++++---------------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 5e6fbc5b5f..f09d848136 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2954,7 +2954,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) if rmvHeight == 0 { processRemoveEntry( - entry, ourBalance, theirBalance, true, + entry, ourBalance, theirBalance, + lntypes.Remote, ) } } @@ -2998,7 +2999,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) if rmvHeight == 0 { processRemoveEntry( - entry, ourBalance, theirBalance, false, + entry, ourBalance, theirBalance, lntypes.Local, ) } } @@ -3019,7 +3020,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) if addHeight == 0 { processAddEntry( - entry, ourBalance, theirBalance, false, + entry, ourBalance, theirBalance, lntypes.Local, ) } @@ -3040,7 +3041,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) if addHeight == 0 { processAddEntry( - entry, ourBalance, theirBalance, true, + entry, ourBalance, theirBalance, lntypes.Remote, ) } @@ -3131,9 +3132,9 @@ func (lc *LightningChannel) fetchParent(entry *paymentDescriptor, // was committed is updated. Keeping track of this inclusion height allows us to // later compact the log once the change is fully committed in both chains. func processAddEntry(htlc *paymentDescriptor, ourBalance, - theirBalance *lnwire.MilliSatoshi, isIncoming bool) { + theirBalance *lnwire.MilliSatoshi, originator lntypes.ChannelParty) { - if isIncoming { + if originator == lntypes.Remote { // If this is a new incoming (un-committed) HTLC, then we need // to update their balance accordingly by subtracting the // amount of the HTLC that are funds pending. @@ -3149,31 +3150,35 @@ func processAddEntry(htlc *paymentDescriptor, ourBalance, // previously added HTLC. If the removal entry has already been processed, it // is skipped. func processRemoveEntry(htlc *paymentDescriptor, ourBalance, - theirBalance *lnwire.MilliSatoshi, isIncoming bool) { + theirBalance *lnwire.MilliSatoshi, originator lntypes.ChannelParty) { switch { // If an incoming HTLC is being settled, then this means that we've // received the preimage either from another subsystem, or the // upstream peer in the route. Therefore, we increase our balance by // the HTLC amount. - case isIncoming && htlc.EntryType == Settle: + case originator == lntypes.Remote && htlc.EntryType == Settle: *ourBalance += htlc.Amount // Otherwise, this HTLC is being failed out, therefore the value of the // HTLC should return to the remote party. - case isIncoming && (htlc.EntryType == Fail || htlc.EntryType == MalformedFail): + case originator == lntypes.Remote && + (htlc.EntryType == Fail || htlc.EntryType == MalformedFail): + *theirBalance += htlc.Amount // If an outgoing HTLC is being settled, then this means that the // downstream party resented the preimage or learned of it via a // downstream peer. In either case, we credit their settled value with // the value of the HTLC. - case !isIncoming && htlc.EntryType == Settle: + case originator == lntypes.Local && htlc.EntryType == Settle: *theirBalance += htlc.Amount // Otherwise, one of our outgoing HTLC's has timed out, so the value of // the HTLC should be returned to our settled balance. - case !isIncoming && (htlc.EntryType == Fail || htlc.EntryType == MalformedFail): + case originator == lntypes.Local && + (htlc.EntryType == Fail || htlc.EntryType == MalformedFail): + *ourBalance += htlc.Amount } } diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 0372e4d166..0bac824a2d 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9089,7 +9089,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { name string startHeights heights whoseCommitChain lntypes.ChannelParty - isIncoming bool + originator lntypes.ChannelParty mutateState bool ourExpectedBalance lnwire.MilliSatoshi theirExpectedBalance lnwire.MilliSatoshi @@ -9105,7 +9105,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: false, + originator: lntypes.Local, mutateState: false, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance, @@ -9126,7 +9126,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Local, - isIncoming: false, + originator: lntypes.Local, mutateState: false, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance, @@ -9147,7 +9147,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Local, - isIncoming: true, + originator: lntypes.Remote, mutateState: false, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance - updateAmount, @@ -9168,7 +9168,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Local, - isIncoming: true, + originator: lntypes.Remote, mutateState: true, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance - updateAmount, @@ -9190,7 +9190,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: false, + originator: lntypes.Local, mutateState: false, ourExpectedBalance: startBalance - updateAmount, theirExpectedBalance: startBalance, @@ -9211,7 +9211,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: false, + originator: lntypes.Local, mutateState: true, ourExpectedBalance: startBalance - updateAmount, theirExpectedBalance: startBalance, @@ -9232,7 +9232,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: removeHeight, }, whoseCommitChain: lntypes.Remote, - isIncoming: false, + originator: lntypes.Local, mutateState: false, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance, @@ -9253,7 +9253,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Local, - isIncoming: false, + originator: lntypes.Local, mutateState: false, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance, @@ -9276,7 +9276,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: true, + originator: lntypes.Remote, mutateState: false, ourExpectedBalance: startBalance + updateAmount, theirExpectedBalance: startBalance, @@ -9299,7 +9299,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: false, + originator: lntypes.Local, mutateState: false, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance + updateAmount, @@ -9322,7 +9322,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: true, + originator: lntypes.Remote, mutateState: false, ourExpectedBalance: startBalance, theirExpectedBalance: startBalance + updateAmount, @@ -9345,7 +9345,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: false, + originator: lntypes.Local, mutateState: false, ourExpectedBalance: startBalance + updateAmount, theirExpectedBalance: startBalance, @@ -9370,7 +9370,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Local, - isIncoming: true, + originator: lntypes.Remote, mutateState: true, ourExpectedBalance: startBalance + updateAmount, theirExpectedBalance: startBalance, @@ -9395,7 +9395,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { remoteRemove: 0, }, whoseCommitChain: lntypes.Remote, - isIncoming: true, + originator: lntypes.Remote, mutateState: true, ourExpectedBalance: startBalance + updateAmount, theirExpectedBalance: startBalance, @@ -9450,7 +9450,7 @@ func TestProcessAddRemoveEntry(t *testing.T) { if heightDual.GetForParty(test.whoseCommitChain) == 0 { process( update, &ourBalance, &theirBalance, - test.isIncoming, + test.originator, ) if test.mutateState { From 4b2a4e36adeea1437265daedd4e924d114ff5607 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 24 Jul 2024 14:57:32 -0700 Subject: [PATCH 09/14] lnwallet: pack htlcView.{OurUpdates|TheirUpdates} into Dual. This commit moves the collection of updates behind a Dual structure. This allows us in a later commit to index into it via a ChannelParty parameter which will simplify the loops in evaluateHTLCView. --- lnwallet/channel.go | 61 +++++++++++++++++++++------------------- lnwallet/channel_test.go | 13 +++++---- lnwallet/commitment.go | 8 +++--- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index f09d848136..b66b14780d 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2675,11 +2675,8 @@ type HtlcView struct { // created using this view. NextHeight uint64 - // OurUpdates are our outgoing HTLCs. - OurUpdates []*paymentDescriptor - - // TheirUpdates are their incoming HTLCs. - TheirUpdates []*paymentDescriptor + // Updates is a Dual of the Local and Remote HTLCs. + Updates lntypes.Dual[[]*paymentDescriptor] // FeePerKw is the fee rate in sat/kw of the commitment transaction. FeePerKw chainfee.SatPerKWeight @@ -2688,13 +2685,13 @@ type HtlcView struct { // AuxOurUpdates returns the outgoing HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxOurUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.OurUpdates) + return fn.Map(newAuxHtlcDescriptor, v.Updates.Local) } // AuxTheirUpdates returns the incoming HTLCs as a read-only copy of // AuxHtlcDescriptors. func (v *HtlcView) AuxTheirUpdates() []AuxHtlcDescriptor { - return fn.Map(newAuxHtlcDescriptor, v.TheirUpdates) + return fn.Map(newAuxHtlcDescriptor, v.Updates.Remote) } // fetchHTLCView returns all the candidate HTLC updates which should be @@ -2728,8 +2725,10 @@ func (lc *LightningChannel) fetchHTLCView(theirLogIndex, } return &HtlcView{ - OurUpdates: ourHTLCs, - TheirUpdates: theirHTLCs, + Updates: lntypes.Dual[[]*paymentDescriptor]{ + Local: ourHTLCs, + Remote: theirHTLCs, + }, } } @@ -2853,15 +2852,15 @@ func (lc *LightningChannel) fetchCommitmentView( // commitment are mutated, we'll manually copy over each HTLC to its // respective slice. c.outgoingHTLCs = make( - []paymentDescriptor, len(filteredHTLCView.OurUpdates), + []paymentDescriptor, len(filteredHTLCView.Updates.Local), ) - for i, htlc := range filteredHTLCView.OurUpdates { + for i, htlc := range filteredHTLCView.Updates.Local { c.outgoingHTLCs[i] = *htlc } c.incomingHTLCs = make( - []paymentDescriptor, len(filteredHTLCView.TheirUpdates), + []paymentDescriptor, len(filteredHTLCView.Updates.Remote), ) - for i, htlc := range filteredHTLCView.TheirUpdates { + for i, htlc := range filteredHTLCView.Updates.Remote { c.incomingHTLCs[i] = *htlc } @@ -2916,7 +2915,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // First we run through non-add entries in both logs, populating the // skip sets. - for _, entry := range view.OurUpdates { + for _, entry := range view.Updates.Local { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -2961,7 +2960,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, } // Do the same for our peer's updates. - for _, entry := range view.TheirUpdates { + for _, entry := range view.Updates.Remote { switch entry.EntryType { // Skip adds for now. They will be processed below. case Add: @@ -3007,7 +3006,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Next we take a second pass through all the log entries, skipping any // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. - for _, entry := range view.OurUpdates { + for _, entry := range view.Updates.Local { isAdd := entry.EntryType == Add if skipUs.Contains(entry.HtlcIndex) || !isAdd { continue @@ -3024,11 +3023,11 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) } - newView.OurUpdates = append(newView.OurUpdates, entry) + newView.Updates.Local = append(newView.Updates.Local, entry) } // Again, we do the same for our peer's updates. - for _, entry := range view.TheirUpdates { + for _, entry := range view.Updates.Remote { isAdd := entry.EntryType == Add if skipThem.Contains(entry.HtlcIndex) || !isAdd { continue @@ -3045,7 +3044,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, ) } - newView.TheirUpdates = append(newView.TheirUpdates, entry) + newView.Updates.Remote = append(newView.Updates.Remote, entry) } // Create a function that is capable of identifying whether or not the @@ -3075,10 +3074,12 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // Collect all of the updates that haven't had their commit heights set // for the commitment chain corresponding to whoseCommitmentChain. - uncommittedUpdates := lntypes.Dual[[]*paymentDescriptor]{ - Local: fn.Filter(isUncommitted, view.OurUpdates), - Remote: fn.Filter(isUncommitted, view.TheirUpdates), - } + uncommittedUpdates := lntypes.MapDual( + view.Updates, + func(us []*paymentDescriptor) []*paymentDescriptor { + return fn.Filter(isUncommitted, us) + }, + ) return newView, uncommittedUpdates, nil } @@ -3746,10 +3747,12 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // appropriate update log, in order to validate the sanity of the // commitment resulting from _actually adding_ this HTLC to the state. if predictOurAdd != nil { - view.OurUpdates = append(view.OurUpdates, predictOurAdd) + view.Updates.Local = append(view.Updates.Local, predictOurAdd) } if predictTheirAdd != nil { - view.TheirUpdates = append(view.TheirUpdates, predictTheirAdd) + view.Updates.Remote = append( + view.Updates.Remote, predictTheirAdd, + ) } ourBalance, theirBalance, commitWeight, filteredView, err := lc.computeView( @@ -3904,7 +3907,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // First check that the remote updates won't violate it's channel // constraints. err = validateUpdates( - filteredView.TheirUpdates, &lc.channelState.RemoteChanCfg, + filteredView.Updates.Remote, &lc.channelState.RemoteChanCfg, ) if err != nil { return err @@ -3913,7 +3916,7 @@ func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter, // Secondly check that our updates won't violate our channel // constraints. err = validateUpdates( - filteredView.OurUpdates, &lc.channelState.LocalChanCfg, + filteredView.Updates.Local, &lc.channelState.LocalChanCfg, ) if err != nil { return err @@ -4700,7 +4703,7 @@ func (lc *LightningChannel) computeView(view *HtlcView, // Now go through all HTLCs at this stage, to calculate the total // weight, needed to calculate the transaction fee. var totalHtlcWeight lntypes.WeightUnit - for _, htlc := range filteredHTLCView.OurUpdates { + for _, htlc := range filteredHTLCView.Updates.Local { if HtlcIsDust( lc.channelState.ChanType, false, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -4711,7 +4714,7 @@ func (lc *LightningChannel) computeView(view *HtlcView, totalHtlcWeight += input.HTLCWeight } - for _, htlc := range filteredHTLCView.TheirUpdates { + for _, htlc := range filteredHTLCView.Updates.Remote { if HtlcIsDust( lc.channelState.ChanType, true, whoseCommitChain, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 0bac824a2d..dd352dab18 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8941,9 +8941,11 @@ func TestEvaluateView(t *testing.T) { } view := &HtlcView{ - OurUpdates: test.ourHtlcs, - TheirUpdates: test.theirHtlcs, - FeePerKw: feePerKw, + Updates: lntypes.Dual[[]*paymentDescriptor]{ + Local: test.ourHtlcs, + Remote: test.theirHtlcs, + }, + FeePerKw: feePerKw, } var ( @@ -8996,11 +8998,12 @@ func TestEvaluateView(t *testing.T) { } checkExpectedHtlcs( - t, result.OurUpdates, test.ourExpectedHtlcs, + t, result.Updates.Local, test.ourExpectedHtlcs, ) checkExpectedHtlcs( - t, result.TheirUpdates, test.theirExpectedHtlcs, + t, result.Updates.Remote, + test.theirExpectedHtlcs, ) if lc.channelState.TotalMSatSent != test.expectSent { diff --git a/lnwallet/commitment.go b/lnwallet/commitment.go index 170efece1b..6d61729a41 100644 --- a/lnwallet/commitment.go +++ b/lnwallet/commitment.go @@ -702,7 +702,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, } numHTLCs := int64(0) - for _, htlc := range filteredHTLCView.OurUpdates { + for _, htlc := range filteredHTLCView.Updates.Local { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -713,7 +713,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, numHTLCs++ } - for _, htlc := range filteredHTLCView.TheirUpdates { + for _, htlc := range filteredHTLCView.Updates.Remote { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -827,7 +827,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, // purposes of sorting. cltvs := make([]uint32, len(commitTx.TxOut)) htlcIndexes := make([]input.HtlcIndex, len(commitTx.TxOut)) - for _, htlc := range filteredHTLCView.OurUpdates { + for _, htlc := range filteredHTLCView.Updates.Local { if HtlcIsDust( cb.chanState.ChanType, false, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, @@ -855,7 +855,7 @@ func (cb *CommitmentBuilder) createUnsignedCommitmentTx(ourBalance, cltvs = append(cltvs, htlc.Timeout) //nolint htlcIndexes = append(htlcIndexes, htlc.HtlcIndex) //nolint } - for _, htlc := range filteredHTLCView.TheirUpdates { + for _, htlc := range filteredHTLCView.Updates.Remote { if HtlcIsDust( cb.chanState.ChanType, true, whoseCommit, feePerKw, htlc.Amount.ToSatoshis(), dustLimit, From 51c28496a4c27d08e599a4482052046e7ce273a1 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 24 Jul 2024 15:52:41 -0700 Subject: [PATCH 10/14] lnwallet: simplify fee calculation in evaluateHTLCView This commit simplifies how we compute the commitment fee rate based off of the live updates. Prior to this commit we processed all of the FeeUpdate paymentDescriptors of both ChannelParty's. Now we only process the last FeeUpdate of the OpeningParty --- go.mod | 2 +- go.sum | 4 ++-- lnwallet/channel.go | 43 ++++++++++++++-------------------------- lnwallet/channel_test.go | 5 +++++ 4 files changed, 23 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index 60af5fb3de..3a559abb68 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn v1.2.1 + github.com/lightningnetwork/lnd/fn v1.2.2 github.com/lightningnetwork/lnd/healthcheck v1.2.5 github.com/lightningnetwork/lnd/kvdb v1.4.10 github.com/lightningnetwork/lnd/queue v1.1.1 diff --git a/go.sum b/go.sum index 312b225780..ea51f3e958 100644 --- a/go.sum +++ b/go.sum @@ -453,8 +453,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= -github.com/lightningnetwork/lnd/fn v1.2.1 h1:pPsVGrwi9QBwdLJzaEGK33wmiVKOxs/zc8H7+MamFf0= -github.com/lightningnetwork/lnd/fn v1.2.1/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= +github.com/lightningnetwork/lnd/fn v1.2.2 h1:rVtmGW1cQTmYce2XdUbRcc5qLDxqu+aQ6IGRpyspakk= +github.com/lightningnetwork/lnd/fn v1.2.2/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= github.com/lightningnetwork/lnd/healthcheck v1.2.5 h1:aTJy5xeBpcWgRtW/PGBDe+LMQEmNm/HQewlQx2jt7OA= github.com/lightningnetwork/lnd/healthcheck v1.2.5/go.mod h1:G7Tst2tVvWo7cx6mSBEToQC5L1XOGxzZTPB29g9Rv2I= github.com/lightningnetwork/lnd/kvdb v1.4.10 h1:vK89IVv1oVH9ubQWU+EmoCQFeVRaC8kfmOrqHbY5zoY= diff --git a/lnwallet/channel.go b/lnwallet/channel.go index b66b14780d..906c37e79d 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2906,6 +2906,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, NextHeight: nextHeight, } + // The fee rate of our view is always the last UpdateFee message from + // the channel's OpeningParty. + openerUpdates := view.Updates.GetForParty(lc.channelState.Initiator()) + feeUpdates := fn.Filter(func(u *paymentDescriptor) bool { + return u.EntryType == FeeUpdate + }, openerUpdates) + lastFeeUpdate := fn.Last(feeUpdates) + lastFeeUpdate.WhenSome(func(pd *paymentDescriptor) { + newView.FeePerKw = chainfee.SatPerKWeight( + pd.Amount.ToSatoshis(), + ) + }) + // We use two maps, one for the local log and one for the remote log to // keep track of which entries we need to skip when creating the final // htlc view. We skip an entry whenever we find a settle or a timeout @@ -2921,21 +2934,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, case Add: continue - // Process fee updates, updating the current feePerKw. + // Skip fee updates because we've already dealt with them above. case FeeUpdate: - h := entry.addCommitHeights.GetForParty( - whoseCommitChain, - ) - - if h == 0 { - // If the update wasn't already locked in, - // update the current fee rate to reflect this - // update. - newView.FeePerKw = chainfee.SatPerKWeight( - entry.Amount.ToSatoshis(), - ) - } - continue } @@ -2966,21 +2966,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, case Add: continue - // Process fee updates, updating the current feePerKw. + // Skip fee updates because we've already dealt with them above. case FeeUpdate: - h := entry.addCommitHeights.GetForParty( - whoseCommitChain, - ) - - if h == 0 { - // If the update wasn't already locked in, - // update the current fee rate to reflect this - // update. - newView.FeePerKw = chainfee.SatPerKWeight( - entry.Amount.ToSatoshis(), - ) - } - continue } diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index dd352dab18..62eb01e3e9 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8650,6 +8650,7 @@ func TestEvaluateView(t *testing.T) { name string ourHtlcs []*paymentDescriptor theirHtlcs []*paymentDescriptor + channelInitiator lntypes.ChannelParty whoseCommitChain lntypes.ChannelParty mutateState bool @@ -8679,6 +8680,7 @@ func TestEvaluateView(t *testing.T) { }{ { name: "our fee update is applied", + channelInitiator: lntypes.Local, whoseCommitChain: lntypes.Local, mutateState: false, ourHtlcs: []*paymentDescriptor{ @@ -8696,6 +8698,7 @@ func TestEvaluateView(t *testing.T) { }, { name: "their fee update is applied", + channelInitiator: lntypes.Remote, whoseCommitChain: lntypes.Local, mutateState: false, ourHtlcs: []*paymentDescriptor{}, @@ -8911,8 +8914,10 @@ func TestEvaluateView(t *testing.T) { test := test t.Run(test.name, func(t *testing.T) { + isInitiator := test.channelInitiator == lntypes.Local lc := LightningChannel{ channelState: &channeldb.OpenChannel{ + IsInitiator: isInitiator, TotalMSatSent: 0, TotalMSatReceived: 0, }, From d5aab4a8c1ab4b9985bae25b40b2c57780bceb92 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 24 Jul 2024 16:16:31 -0700 Subject: [PATCH 11/14] lnwallet: consolidate loops in evaluateHTLCView We had four for-loops in evaluateHTLCView that were exact mirror images of each other. By making use of the new ChannelParty and Dual facilities introduced in prior commits, we consolidate these into two for-loops. --- lnwallet/channel.go | 146 ++++++++++++++++---------------------------- 1 file changed, 54 insertions(+), 92 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 906c37e79d..ae6f1b9ecb 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2905,6 +2905,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, FeePerKw: view.FeePerKw, NextHeight: nextHeight, } + noUncommitted := lntypes.Dual[[]*paymentDescriptor]{} // The fee rate of our view is always the last UpdateFee message from // the channel's OpeningParty. @@ -2923,115 +2924,76 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // keep track of which entries we need to skip when creating the final // htlc view. We skip an entry whenever we find a settle or a timeout // modifying an entry. - skipUs := fn.NewSet[uint64]() - skipThem := fn.NewSet[uint64]() - - // First we run through non-add entries in both logs, populating the - // skip sets. - for _, entry := range view.Updates.Local { - switch entry.EntryType { - // Skip adds for now. They will be processed below. - case Add: - continue - - // Skip fee updates because we've already dealt with them above. - case FeeUpdate: - continue - } - - addEntry, err := lc.fetchParent( - entry, whoseCommitChain, lntypes.Remote, - ) - if err != nil { - return nil, lntypes.Dual[[]*paymentDescriptor]{}, err - } + skip := lntypes.Dual[fn.Set[uint64]]{ + Local: fn.NewSet[uint64](), + Remote: fn.NewSet[uint64](), + } + + parties := [2]lntypes.ChannelParty{lntypes.Local, lntypes.Remote} + for _, party := range parties { + // First we run through non-add entries in both logs, + // populating the skip sets. + for _, entry := range view.Updates.GetForParty(party) { + switch entry.EntryType { + // Skip adds for now. They will be processed below. + case Add: + continue - skipThem.Add(addEntry.HtlcIndex) + // Skip fee updates because we've already dealt with + // them above. + case FeeUpdate: + continue + } - rmvHeight := entry.removeCommitHeights.GetForParty( - whoseCommitChain, - ) - if rmvHeight == 0 { - processRemoveEntry( - entry, ourBalance, theirBalance, - lntypes.Remote, + addEntry, err := lc.fetchParent( + entry, whoseCommitChain, party.CounterParty(), ) - } - } - - // Do the same for our peer's updates. - for _, entry := range view.Updates.Remote { - switch entry.EntryType { - // Skip adds for now. They will be processed below. - case Add: - continue - - // Skip fee updates because we've already dealt with them above. - case FeeUpdate: - continue - } - - addEntry, err := lc.fetchParent( - entry, whoseCommitChain, lntypes.Local, - ) - if err != nil { - return nil, lntypes.Dual[[]*paymentDescriptor]{}, err - } + if err != nil { + return nil, noUncommitted, err + } - skipUs.Add(addEntry.HtlcIndex) + skipSet := skip.GetForParty(party.CounterParty()) + skipSet.Add(addEntry.HtlcIndex) - rmvHeight := entry.removeCommitHeights.GetForParty( - whoseCommitChain, - ) - if rmvHeight == 0 { - processRemoveEntry( - entry, ourBalance, theirBalance, lntypes.Local, + rmvHeight := entry.removeCommitHeights.GetForParty( + whoseCommitChain, ) + if rmvHeight == 0 { + processRemoveEntry( + entry, ourBalance, theirBalance, + party.CounterParty(), + ) + } } } // Next we take a second pass through all the log entries, skipping any // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. - for _, entry := range view.Updates.Local { - isAdd := entry.EntryType == Add - if skipUs.Contains(entry.HtlcIndex) || !isAdd { - continue - } + for _, party := range parties { + for _, entry := range view.Updates.GetForParty(party) { + isAdd := entry.EntryType == Add + skipSet := skip.GetForParty(party) + if skipSet.Contains(entry.HtlcIndex) || !isAdd { + continue + } - // Skip the entries that have already had their add commit - // height set for this commit chain. - addHeight := entry.addCommitHeights.GetForParty( - whoseCommitChain, - ) - if addHeight == 0 { - processAddEntry( - entry, ourBalance, theirBalance, lntypes.Local, + // Skip the entries that have already had their add + // commit height set for this commit chain. + addHeight := entry.addCommitHeights.GetForParty( + whoseCommitChain, ) - } - - newView.Updates.Local = append(newView.Updates.Local, entry) - } - - // Again, we do the same for our peer's updates. - for _, entry := range view.Updates.Remote { - isAdd := entry.EntryType == Add - if skipThem.Contains(entry.HtlcIndex) || !isAdd { - continue - } + if addHeight == 0 { + processAddEntry( + entry, ourBalance, theirBalance, party, + ) + } - // Skip the entries that have already had their add commit - // height set for this commit chain. - addHeight := entry.addCommitHeights.GetForParty( - whoseCommitChain, - ) - if addHeight == 0 { - processAddEntry( - entry, ourBalance, theirBalance, lntypes.Remote, + prevUpdates := newView.Updates.GetForParty(party) + newView.Updates.SetForParty( + party, append(prevUpdates, entry), ) } - - newView.Updates.Remote = append(newView.Updates.Remote, entry) } // Create a function that is capable of identifying whether or not the From 1222cb8b10afa0c31d286e56a883921582e76397 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 24 Jul 2024 16:30:55 -0700 Subject: [PATCH 12/14] lnwallet: remove continue statements from evaluateHTLCView loops This further reduces loop complexity in evaluateHTLCView by using explicit filter steps rather than loop continue statements. --- lnwallet/channel.go | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index ae6f1b9ecb..60c7faa218 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2933,18 +2933,16 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, for _, party := range parties { // First we run through non-add entries in both logs, // populating the skip sets. - for _, entry := range view.Updates.GetForParty(party) { - switch entry.EntryType { - // Skip adds for now. They will be processed below. - case Add: - continue - - // Skip fee updates because we've already dealt with - // them above. - case FeeUpdate: - continue + resolutions := fn.Filter(func(pd *paymentDescriptor) bool { + switch pd.EntryType { + case Settle, Fail, MalformedFail: + return true + default: + return false } + }, view.Updates.GetForParty(party)) + for _, entry := range resolutions { addEntry, err := lc.fetchParent( entry, whoseCommitChain, party.CounterParty(), ) @@ -2971,13 +2969,12 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, // settled HTLCs, and debiting the chain state balance due to any newly // added HTLCs. for _, party := range parties { - for _, entry := range view.Updates.GetForParty(party) { - isAdd := entry.EntryType == Add - skipSet := skip.GetForParty(party) - if skipSet.Contains(entry.HtlcIndex) || !isAdd { - continue - } + liveAdds := fn.Filter(func(pd *paymentDescriptor) bool { + return pd.EntryType == Add && + !skip.GetForParty(party).Contains(pd.HtlcIndex) + }, view.Updates.GetForParty(party)) + for _, entry := range liveAdds { // Skip the entries that have already had their add // commit height set for this commit chain. addHeight := entry.addCommitHeights.GetForParty( @@ -2988,12 +2985,9 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, entry, ourBalance, theirBalance, party, ) } - - prevUpdates := newView.Updates.GetForParty(party) - newView.Updates.SetForParty( - party, append(prevUpdates, entry), - ) } + + newView.Updates.SetForParty(party, liveAdds) } // Create a function that is capable of identifying whether or not the From f0eecfa2cd144fa776e74dd17c8d07d25623f280 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 24 Jul 2024 17:24:08 -0700 Subject: [PATCH 13/14] lnwallet: inline and remove process[Add|Remove]Entry This commit observes that processAddEntry and processRemoveEntry are only invoked at a single call-site. Here we inline them at their call-sites, which will unlock further simplifications of the code that will allow us to remove pointer mutations in favor of explicit expression oriented programming. We also delete the tests associated with these functions, the overall functionality is implicitly tested by the TestEvaluateHTLCView tests. --- lnwallet/channel.go | 114 +++++------ lnwallet/channel_test.go | 431 --------------------------------------- 2 files changed, 51 insertions(+), 494 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 60c7faa218..f5666dabf8 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2957,10 +2957,44 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, whoseCommitChain, ) if rmvHeight == 0 { - processRemoveEntry( - entry, ourBalance, theirBalance, - party.CounterParty(), - ) + switch { + // If an incoming HTLC is being settled, then + // this means that we've received the preimage + // either from another subsystem, or the + // upstream peer in the route. Therefore, we + // increase our balance by the HTLC amount. + case party.CounterParty() == lntypes.Remote && + entry.EntryType == Settle: + + *ourBalance += entry.Amount + + // Otherwise, this HTLC is being failed out, + // therefore the value of the HTLC should + // return to the remote party. + case party.CounterParty() == lntypes.Remote && + entry.EntryType != Settle: + + *theirBalance += entry.Amount + + // If an outgoing HTLC is being settled, then + // this means that the downstream party + // resented the preimage or learned of it via a + // downstream peer. In either case, we credit + // their settled value with the value of the + // HTLC. + case party.CounterParty() == lntypes.Local && + entry.EntryType == Settle: + + *theirBalance += entry.Amount + + // Otherwise, one of our outgoing HTLC's has + // timed out, so the value of the HTLC should + // be returned to our settled balance. + case party.CounterParty() == lntypes.Local && + entry.EntryType != Settle: + + *ourBalance += entry.Amount + } } } } @@ -2981,9 +3015,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, whoseCommitChain, ) if addHeight == 0 { - processAddEntry( - entry, ourBalance, theirBalance, party, - ) + if party == lntypes.Remote { + // If this is a new incoming + // (un-committed) HTLC, then we need + // to update their balance accordingly + // by subtracting the amount of the + // HTLC that are funds pending. + *theirBalance -= entry.Amount + } else { + // Similarly, we need to debit our + // balance if this is an out going HTLC + // to reflect the pending balance. + *ourBalance -= entry.Amount + } } } @@ -3071,62 +3115,6 @@ func (lc *LightningChannel) fetchParent(entry *paymentDescriptor, return addEntry, nil } -// processAddEntry evaluates the effect of an add entry within the HTLC log. -// If the HTLC hasn't yet been committed in either chain, then the height it -// was committed is updated. Keeping track of this inclusion height allows us to -// later compact the log once the change is fully committed in both chains. -func processAddEntry(htlc *paymentDescriptor, ourBalance, - theirBalance *lnwire.MilliSatoshi, originator lntypes.ChannelParty) { - - if originator == lntypes.Remote { - // If this is a new incoming (un-committed) HTLC, then we need - // to update their balance accordingly by subtracting the - // amount of the HTLC that are funds pending. - *theirBalance -= htlc.Amount - } else { - // Similarly, we need to debit our balance if this is an out - // going HTLC to reflect the pending balance. - *ourBalance -= htlc.Amount - } -} - -// processRemoveEntry processes a log entry which settles or times out a -// previously added HTLC. If the removal entry has already been processed, it -// is skipped. -func processRemoveEntry(htlc *paymentDescriptor, ourBalance, - theirBalance *lnwire.MilliSatoshi, originator lntypes.ChannelParty) { - - switch { - // If an incoming HTLC is being settled, then this means that we've - // received the preimage either from another subsystem, or the - // upstream peer in the route. Therefore, we increase our balance by - // the HTLC amount. - case originator == lntypes.Remote && htlc.EntryType == Settle: - *ourBalance += htlc.Amount - - // Otherwise, this HTLC is being failed out, therefore the value of the - // HTLC should return to the remote party. - case originator == lntypes.Remote && - (htlc.EntryType == Fail || htlc.EntryType == MalformedFail): - - *theirBalance += htlc.Amount - - // If an outgoing HTLC is being settled, then this means that the - // downstream party resented the preimage or learned of it via a - // downstream peer. In either case, we credit their settled value with - // the value of the HTLC. - case originator == lntypes.Local && htlc.EntryType == Settle: - *theirBalance += htlc.Amount - - // Otherwise, one of our outgoing HTLC's has timed out, so the value of - // the HTLC should be returned to our settled balance. - case originator == lntypes.Local && - (htlc.EntryType == Fail || htlc.EntryType == MalformedFail): - - *ourBalance += htlc.Amount - } -} - // generateRemoteHtlcSigJobs generates a series of HTLC signature jobs for the // sig pool, along with a channel that if closed, will cancel any jobs after // they have been submitted to the sigPool. This method is to be used when diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 62eb01e3e9..f99fb141d0 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9055,437 +9055,6 @@ type heights struct { remoteRemove uint64 } -func checkHeights(t *testing.T, update *paymentDescriptor, expected heights) { - updateHeights := heights{ - localAdd: update.addCommitHeights.Local, - localRemove: update.removeCommitHeights.Local, - remoteAdd: update.addCommitHeights.Remote, - remoteRemove: update.removeCommitHeights.Remote, - } - - if !reflect.DeepEqual(updateHeights, expected) { - t.Fatalf("expected: %v, got: %v", expected, updateHeights) - } -} - -// TestProcessAddRemoveEntry tests the updating of our and their balances when -// we process adds, settles and fails. It also tests the mutating of add and -// remove heights. -func TestProcessAddRemoveEntry(t *testing.T) { - const ( - // addHeight is a non-zero addHeight that is used for htlc - // add heights. - addHeight = 100 - - // removeHeight is a non-zero removeHeight that is used for - // htlc remove heights. - removeHeight = 200 - - // nextHeight is a constant that we use for the nextHeight in - // all unit tests. - nextHeight = 400 - - // updateAmount is the amount that the update is set to. - updateAmount = lnwire.MilliSatoshi(10) - - // startBalance is a balance we start both sides out with - // so that balances can be incremented. - startBalance = lnwire.MilliSatoshi(100) - ) - - tests := []struct { - name string - startHeights heights - whoseCommitChain lntypes.ChannelParty - originator lntypes.ChannelParty - mutateState bool - ourExpectedBalance lnwire.MilliSatoshi - theirExpectedBalance lnwire.MilliSatoshi - expectedHeights heights - updateType updateType - }{ - { - name: "add, remote chain, already processed", - startHeights: heights{ - localAdd: 0, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Local, - mutateState: false, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: 0, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Add, - }, - { - name: "add, local chain, already processed", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Local, - originator: lntypes.Local, - mutateState: false, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Add, - }, - { - name: "incoming add, local chain, not mutated", - startHeights: heights{ - localAdd: 0, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Local, - originator: lntypes.Remote, - mutateState: false, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance - updateAmount, - expectedHeights: heights{ - localAdd: 0, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Add, - }, - { - name: "incoming add, local chain, mutated", - startHeights: heights{ - localAdd: 0, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Local, - originator: lntypes.Remote, - mutateState: true, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance - updateAmount, - expectedHeights: heights{ - localAdd: nextHeight, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Add, - }, - - { - name: "outgoing add, remote chain, not mutated", - startHeights: heights{ - localAdd: 0, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Local, - mutateState: false, - ourExpectedBalance: startBalance - updateAmount, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: 0, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Add, - }, - { - name: "outgoing add, remote chain, mutated", - startHeights: heights{ - localAdd: 0, - remoteAdd: 0, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Local, - mutateState: true, - ourExpectedBalance: startBalance - updateAmount, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: 0, - remoteAdd: nextHeight, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Add, - }, - { - name: "settle, remote chain, already processed", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: removeHeight, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Local, - mutateState: false, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: removeHeight, - }, - updateType: Settle, - }, - { - name: "settle, local chain, already processed", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: removeHeight, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Local, - originator: lntypes.Local, - mutateState: false, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: removeHeight, - remoteRemove: 0, - }, - updateType: Settle, - }, - { - // Remote chain, and not processed yet. Incoming settle, - // so we expect our balance to increase. - name: "incoming settle", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Remote, - mutateState: false, - ourExpectedBalance: startBalance + updateAmount, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Settle, - }, - { - // Remote chain, and not processed yet. Incoming settle, - // so we expect our balance to increase. - name: "outgoing settle", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Local, - mutateState: false, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance + updateAmount, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Settle, - }, - { - // Remote chain, and not processed yet. Incoming fail, - // so we expect their balance to increase. - name: "incoming fail", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Remote, - mutateState: false, - ourExpectedBalance: startBalance, - theirExpectedBalance: startBalance + updateAmount, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Fail, - }, - { - // Remote chain, and not processed yet. Outgoing fail, - // so we expect our balance to increase. - name: "outgoing fail", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Local, - mutateState: false, - ourExpectedBalance: startBalance + updateAmount, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - updateType: Fail, - }, - { - // Local chain, and not processed yet. Incoming settle, - // so we expect our balance to increase. Mutate is - // true, so we expect our remove removeHeight to have - // changed. - name: "fail, our remove height mutated", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Local, - originator: lntypes.Remote, - mutateState: true, - ourExpectedBalance: startBalance + updateAmount, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: nextHeight, - remoteRemove: 0, - }, - updateType: Settle, - }, - { - // Remote chain, and not processed yet. Incoming settle, - // so we expect our balance to increase. Mutate is - // true, so we expect their remove removeHeight to have - // changed. - name: "fail, their remove height mutated", - startHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: 0, - }, - whoseCommitChain: lntypes.Remote, - originator: lntypes.Remote, - mutateState: true, - ourExpectedBalance: startBalance + updateAmount, - theirExpectedBalance: startBalance, - expectedHeights: heights{ - localAdd: addHeight, - remoteAdd: addHeight, - localRemove: 0, - remoteRemove: nextHeight, - }, - updateType: Settle, - }, - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { - t.Parallel() - - heights := test.startHeights - update := &paymentDescriptor{ - Amount: updateAmount, - addCommitHeights: lntypes.Dual[uint64]{ - Local: heights.localAdd, - Remote: heights.remoteAdd, - }, - removeCommitHeights: lntypes.Dual[uint64]{ - Local: heights.localRemove, - Remote: heights.remoteRemove, - }, - EntryType: test.updateType, - } - - var ( - // Start both parties off with an initial - // balance. Copy by value here so that we do - // not mutate the startBalance constant. - ourBalance, theirBalance = startBalance, - startBalance - ) - - // Choose the processing function we need based on the - // update type. Process remove is used for settles, - // fails and malformed htlcs. - process := processRemoveEntry - heightDual := &update.removeCommitHeights - if test.updateType == Add { - process = processAddEntry - heightDual = &update.addCommitHeights - } - - if heightDual.GetForParty(test.whoseCommitChain) == 0 { - process( - update, &ourBalance, &theirBalance, - test.originator, - ) - - if test.mutateState { - heightDual.SetForParty( - test.whoseCommitChain, - nextHeight, - ) - } - } - - // Check that balances were updated as expected. - if ourBalance != test.ourExpectedBalance { - t.Fatalf("expected our balance: %v, got: %v", - test.ourExpectedBalance, ourBalance) - } - - if theirBalance != test.theirExpectedBalance { - t.Fatalf("expected their balance: %v, got: %v", - test.theirExpectedBalance, theirBalance) - } - - // Check that heights on the update are as expected. - checkHeights(t, update, test.expectedHeights) - }) - } -} - // TestChannelUnsignedAckedFailure tests that unsigned acked updates are // properly restored after signing for them and disconnecting. // From 5307e7a5e44bd10433f02046f2befaae40284b03 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 24 Jul 2024 17:53:05 -0700 Subject: [PATCH 14/14] lnwallet: return balance changes rather than modifying references Here we return the balance deltas from evaluateHTLCView rather than passing in references to variables that will be modified. It is a far cleaner and compositional approach which allows readers of this code to more effectively reason about the code without having to keep the whole codebase in their head. --- lnwallet/channel.go | 114 ++++++++++++++++++++------------------- lnwallet/channel_test.go | 17 ++---- 2 files changed, 64 insertions(+), 67 deletions(-) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index f5666dabf8..a17678ce4c 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2893,10 +2893,9 @@ func fundingTxIn(chanState *channeldb.OpenChannel) wire.TxIn { // 1. The new htlcView reflecting the current channel state. // 2. A Dual of the updates which have not yet been committed in // 'whoseCommitChain's commitment chain. -func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, - theirBalance *lnwire.MilliSatoshi, nextHeight uint64, - whoseCommitChain lntypes.ChannelParty) (*HtlcView, - lntypes.Dual[[]*paymentDescriptor], error) { +func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, + whoseCommitChain lntypes.ChannelParty, nextHeight uint64) (*HtlcView, + lntypes.Dual[[]*paymentDescriptor], lntypes.Dual[int64], error) { // We initialize the view's fee rate to the fee rate of the unfiltered // view. If any fee updates are found when evaluating the view, it will @@ -2929,6 +2928,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, Remote: fn.NewSet[uint64](), } + balanceDeltas := lntypes.Dual[int64]{} + parties := [2]lntypes.ChannelParty{lntypes.Local, lntypes.Remote} for _, party := range parties { // First we run through non-add entries in both logs, @@ -2947,7 +2948,8 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, entry, whoseCommitChain, party.CounterParty(), ) if err != nil { - return nil, noUncommitted, err + noDeltas := lntypes.Dual[int64]{} + return nil, noUncommitted, noDeltas, err } skipSet := skip.GetForParty(party.CounterParty()) @@ -2959,41 +2961,30 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, if rmvHeight == 0 { switch { // If an incoming HTLC is being settled, then - // this means that we've received the preimage - // either from another subsystem, or the - // upstream peer in the route. Therefore, we - // increase our balance by the HTLC amount. - case party.CounterParty() == lntypes.Remote && - entry.EntryType == Settle: - - *ourBalance += entry.Amount + // this means that the preimage has been + // received by the settling party Therefore, we + // increase the settling party's balance by the + // HTLC amount. + case entry.EntryType == Settle: + delta := int64(entry.Amount) + balanceDeltas.ModifyForParty( + party, + func(acc int64) int64 { + return acc + delta + }, + ) // Otherwise, this HTLC is being failed out, // therefore the value of the HTLC should - // return to the remote party. - case party.CounterParty() == lntypes.Remote && - entry.EntryType != Settle: - - *theirBalance += entry.Amount - - // If an outgoing HTLC is being settled, then - // this means that the downstream party - // resented the preimage or learned of it via a - // downstream peer. In either case, we credit - // their settled value with the value of the - // HTLC. - case party.CounterParty() == lntypes.Local && - entry.EntryType == Settle: - - *theirBalance += entry.Amount - - // Otherwise, one of our outgoing HTLC's has - // timed out, so the value of the HTLC should - // be returned to our settled balance. - case party.CounterParty() == lntypes.Local && - entry.EntryType != Settle: - - *ourBalance += entry.Amount + // return to the failing party's counterparty. + case entry.EntryType != Settle: + delta := int64(entry.Amount) + balanceDeltas.ModifyForParty( + party.CounterParty(), + func(acc int64) int64 { + return acc + delta + }, + ) } } } @@ -3015,19 +3006,19 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, whoseCommitChain, ) if addHeight == 0 { - if party == lntypes.Remote { - // If this is a new incoming - // (un-committed) HTLC, then we need - // to update their balance accordingly - // by subtracting the amount of the - // HTLC that are funds pending. - *theirBalance -= entry.Amount - } else { - // Similarly, we need to debit our - // balance if this is an out going HTLC - // to reflect the pending balance. - *ourBalance -= entry.Amount - } + // If this is a new incoming (un-committed) + // HTLC, then we need to update their balance + // accordingly by subtracting the amount of + // the HTLC that are funds pending. + // Similarly, we need to debit our balance if + // this is an out going HTLC to reflect the + // pending balance. + balanceDeltas.ModifyForParty( + party, + func(acc int64) int64 { + return acc - int64(entry.Amount) + }, + ) } } @@ -3068,7 +3059,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *HtlcView, ourBalance, }, ) - return newView, uncommittedUpdates, nil + return newView, uncommittedUpdates, balanceDeltas, nil } // fetchParent is a helper that looks up update log parent entries in the @@ -4589,13 +4580,26 @@ func (lc *LightningChannel) computeView(view *HtlcView, // channel constraints to the final commitment state. If any fee // updates are found in the logs, the commitment fee rate should be // changed, so we'll also set the feePerKw to this new value. - filteredHTLCView, uncommitted, err := lc.evaluateHTLCView( - view, &ourBalance, &theirBalance, nextHeight, whoseCommitChain, + filteredHTLCView, uncommitted, deltas, err := lc.evaluateHTLCView( + view, whoseCommitChain, nextHeight, ) if err != nil { return 0, 0, 0, nil, err } + // Add the balance deltas to the balances we got from the commitment + // state. + if deltas.Local >= 0 { + ourBalance += lnwire.MilliSatoshi(deltas.Local) + } else { + ourBalance -= lnwire.MilliSatoshi(-1 * deltas.Local) + } + if deltas.Remote >= 0 { + theirBalance += lnwire.MilliSatoshi(deltas.Remote) + } else { + theirBalance -= lnwire.MilliSatoshi(-1 * deltas.Remote) + } + if updateState { for _, party := range lntypes.BothParties { for _, u := range uncommitted.GetForParty(party) { @@ -4619,8 +4623,8 @@ func (lc *LightningChannel) computeView(view *HtlcView, } // We need to first check ourBalance and theirBalance to be negative - // because MilliSathoshi is a unsigned type and can underflow in - // `evaluateHTLCView`. This should never happen for views which do not + // because MilliSathoshi is a unsigned type and can underflow in the + // code above. This should never happen for views which do not // include new updates (remote or local). if int64(ourBalance) < 0 { err := fmt.Errorf("%w: our balance", ErrBelowChanReserve) diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index f99fb141d0..49e4321beb 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -8953,19 +8953,12 @@ func TestEvaluateView(t *testing.T) { FeePerKw: feePerKw, } - var ( - // Create vars to store balance changes. We do - // not check these values in this test because - // balance modification happens on the htlc - // processing level. - ourBalance lnwire.MilliSatoshi - theirBalance lnwire.MilliSatoshi - ) - // Evaluate the htlc view, mutate as test expects. - result, uncommitted, err := lc.evaluateHTLCView( - view, &ourBalance, &theirBalance, nextHeight, - test.whoseCommitChain, + // We do not check the balance deltas in this test + // because balance modification happens on the htlc + // processing level. + result, uncommitted, _, err := lc.evaluateHTLCView( + view, test.whoseCommitChain, nextHeight, ) if err != nil {