diff --git a/channeldb/mp_payment.go b/channeldb/mp_payment.go index cf5669a505..b94bfd418d 100644 --- a/channeldb/mp_payment.go +++ b/channeldb/mp_payment.go @@ -10,7 +10,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" + sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) @@ -45,12 +48,19 @@ type HTLCAttemptInfo struct { // in which the payment's PaymentHash in the PaymentCreationInfo should // be used. Hash *lntypes.Hash + + // onionBlob is the cached value for onion blob created from the sphinx + // construction. + onionBlob [lnwire.OnionPacketSize]byte + + // circuit is the cached value for sphinx circuit. + circuit *sphinx.Circuit } // NewHtlcAttempt creates a htlc attempt. func NewHtlcAttempt(attemptID uint64, sessionKey *btcec.PrivateKey, route route.Route, attemptTime time.Time, - hash *lntypes.Hash) *HTLCAttempt { + hash *lntypes.Hash) (*HTLCAttempt, error) { var scratch [btcec.PrivKeyBytesLen]byte copy(scratch[:], sessionKey.Serialize()) @@ -64,7 +74,11 @@ func NewHtlcAttempt(attemptID uint64, sessionKey *btcec.PrivateKey, Hash: hash, } - return &HTLCAttempt{HTLCAttemptInfo: info} + if err := info.attachOnionBlobAndCircuit(); err != nil { + return nil, err + } + + return &HTLCAttempt{HTLCAttemptInfo: info}, nil } // SessionKey returns the ephemeral key used for a htlc attempt. This function @@ -79,6 +93,45 @@ func (h *HTLCAttemptInfo) SessionKey() *btcec.PrivateKey { return h.cachedSessionKey } +// OnionBlob returns the onion blob created from the sphinx construction. +func (h *HTLCAttemptInfo) OnionBlob() ([lnwire.OnionPacketSize]byte, error) { + var zeroBytes [lnwire.OnionPacketSize]byte + if h.onionBlob == zeroBytes { + if err := h.attachOnionBlobAndCircuit(); err != nil { + return zeroBytes, err + } + } + + return h.onionBlob, nil +} + +// Circuit returns the sphinx circuit for this attempt. +func (h *HTLCAttemptInfo) Circuit() (*sphinx.Circuit, error) { + if h.circuit == nil { + if err := h.attachOnionBlobAndCircuit(); err != nil { + return nil, err + } + } + + return h.circuit, nil +} + +// attachOnionBlobAndCircuit creates a sphinx packet and caches the onion blob +// and circuit for this attempt. +func (h *HTLCAttemptInfo) attachOnionBlobAndCircuit() error { + onionBlob, circuit, err := generateSphinxPacket( + &h.Route, h.Hash[:], h.SessionKey(), + ) + if err != nil { + return err + } + + copy(h.onionBlob[:], onionBlob) + h.circuit = circuit + + return nil +} + // HTLCAttempt contains information about a specific HTLC attempt for a given // payment. It contains the HTLCAttemptInfo used to send the HTLC, as well // as a timestamp and any known outcome of the attempt. @@ -629,3 +682,69 @@ func serializeTime(w io.Writer, t time.Time) error { _, err := w.Write(scratch[:]) return err } + +// generateSphinxPacket generates then encodes a sphinx packet which encodes +// the onion route specified by the passed layer 3 route. The blob returned +// from this function can immediately be included within an HTLC add packet to +// be sent to the first hop within the route. +func generateSphinxPacket(rt *route.Route, paymentHash []byte, + sessionKey *btcec.PrivateKey) ([]byte, *sphinx.Circuit, error) { + + // Now that we know we have an actual route, we'll map the route into a + // sphinx payment path which includes per-hop payloads for each hop + // that give each node within the route the necessary information + // (fees, CLTV value, etc.) to properly forward the payment. + sphinxPath, err := rt.ToSphinxPath() + if err != nil { + return nil, nil, err + } + + log.Tracef("Constructed per-hop payloads for payment_hash=%x: %v", + paymentHash, lnutils.NewLogClosure(func() string { + path := make( + []sphinx.OnionHop, sphinxPath.TrueRouteLength(), + ) + for i := range path { + hopCopy := sphinxPath[i] + path[i] = hopCopy + } + + return spew.Sdump(path) + }), + ) + + // Next generate the onion routing packet which allows us to perform + // privacy preserving source routing across the network. + sphinxPacket, err := sphinx.NewOnionPacket( + sphinxPath, sessionKey, paymentHash, + sphinx.DeterministicPacketFiller, + ) + if err != nil { + return nil, nil, err + } + + // Finally, encode Sphinx packet using its wire representation to be + // included within the HTLC add packet. + var onionBlob bytes.Buffer + if err := sphinxPacket.Encode(&onionBlob); err != nil { + return nil, nil, err + } + + log.Tracef("Generated sphinx packet: %v", + lnutils.NewLogClosure(func() string { + // We make a copy of the ephemeral key and unset the + // internal curve here in order to keep the logs from + // getting noisy. + key := *sphinxPacket.EphemeralKey + packetCopy := *sphinxPacket + packetCopy.EphemeralKey = &key + + return spew.Sdump(packetCopy) + }), + ) + + return onionBlob.Bytes(), &sphinx.Circuit{ + SessionKey: sessionKey, + PaymentPath: sphinxPath.NodeKeys(), + }, nil +} diff --git a/channeldb/mp_payment_test.go b/channeldb/mp_payment_test.go index 51eda72bb0..13e39871a9 100644 --- a/channeldb/mp_payment_test.go +++ b/channeldb/mp_payment_test.go @@ -5,12 +5,22 @@ import ( "fmt" "testing" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" ) +var ( + testHash = [32]byte{ + 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, + 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, + 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, + 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, + } +) + // TestLazySessionKeyDeserialize tests that we can read htlc attempt session // keys that were previously serialized as a private key as raw bytes. func TestLazySessionKeyDeserialize(t *testing.T) { @@ -578,3 +588,15 @@ func makeAttemptInfo(total, amtForwarded int) HTLCAttemptInfo { }, } } + +// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket +// function is able to gracefully handle being passed a nil set of hops for the +// route by the caller. +func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { + t.Parallel() + + sessionKey, _ := btcec.NewPrivateKey() + emptyRoute := &route.Route{} + _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) + require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) +} diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index 2369d2f7e6..5e0e2bc3c7 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -115,6 +115,10 @@ var ( // amount exceed the total amount. ErrSentExceedsTotal = errors.New("total sent exceeds total amount") + // ErrRegisterAttempt is returned when a new htlc attempt cannot be + // registered. + ErrRegisterAttempt = errors.New("cannot register htlc attempt") + // errNoAttemptInfo is returned when no attempt info is stored yet. errNoAttemptInfo = errors.New("unable to find attempt info for " + "inflight payment") @@ -342,7 +346,8 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, // Check if registering a new attempt is allowed. if err := payment.Registrable(); err != nil { - return err + return fmt.Errorf("%w: %v", ErrRegisterAttempt, + err.Error()) } // If the final hop has encrypted data, then we know this is a diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index fb965bb321..f03a8b23a7 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -28,7 +28,7 @@ func genPreimage() ([32]byte, error) { return preimage, nil } -func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, +func genInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo, lntypes.Preimage, error) { preimage, err := genPreimage() @@ -38,9 +38,14 @@ func genInfo() (*PaymentCreationInfo, *HTLCAttemptInfo, } rhash := sha256.Sum256(preimage[:]) - attempt := NewHtlcAttempt( - 0, priv, *testRoute.Copy(), time.Time{}, nil, + var hash lntypes.Hash + copy(hash[:], rhash[:]) + + attempt, err := NewHtlcAttempt( + 0, priv, *testRoute.Copy(), time.Time{}, &hash, ) + require.NoError(t, err) + return &PaymentCreationInfo{ PaymentIdentifier: rhash, Value: testRoute.ReceiverAmt(), @@ -60,7 +65,7 @@ func TestPaymentControlSwitchFail(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Sends base htlc message which initiate StatusInFlight. @@ -196,7 +201,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Sends base htlc message which initiate base status and move it to @@ -266,7 +271,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { pControl := NewPaymentControl(db) - info, _, preimg, err := genInfo() + info, _, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Attempt to complete the payment should fail. @@ -291,7 +296,7 @@ func TestPaymentControlFailsWithoutInFlight(t *testing.T) { pControl := NewPaymentControl(db) - info, _, _, err := genInfo() + info, _, _, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Calling Fail should return an error. @@ -346,7 +351,7 @@ func TestPaymentControlDeleteNonInFlight(t *testing.T) { var numSuccess, numInflight int for _, p := range payments { - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to generate htlc message: %v", err) } @@ -684,7 +689,7 @@ func TestPaymentControlMultiShard(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to generate htlc message: %v", err) } @@ -836,9 +841,9 @@ func TestPaymentControlMultiShard(t *testing.T) { b.AttemptID = 3 _, err = pControl.RegisterAttempt(info.PaymentIdentifier, &b) if test.settleFirst { - require.ErrorIs(t, err, ErrPaymentPendingSettled) + require.ErrorIs(t, err, ErrRegisterAttempt) } else { - require.ErrorIs(t, err, ErrPaymentPendingFailed) + require.ErrorIs(t, err, ErrRegisterAttempt) } assertPaymentStatus(t, pControl, info.PaymentIdentifier, StatusInFlight) @@ -892,10 +897,7 @@ func TestPaymentControlMultiShard(t *testing.T) { require.NoError(t, err, "unable to fail") } - var ( - finalStatus PaymentStatus - registerErr error - ) + var finalStatus PaymentStatus switch { // If one of the attempts settled but the other failed with @@ -903,21 +905,17 @@ func TestPaymentControlMultiShard(t *testing.T) { // settled. case test.settleFirst && !test.settleLast: finalStatus = StatusSucceeded - registerErr = ErrPaymentAlreadySucceeded case !test.settleFirst && test.settleLast: finalStatus = StatusSucceeded - registerErr = ErrPaymentAlreadySucceeded // If both failed, we end up in a failed status. case !test.settleFirst && !test.settleLast: finalStatus = StatusFailed - registerErr = ErrPaymentAlreadyFailed // Otherwise, the payment has a succeed status. case test.settleFirst && test.settleLast: finalStatus = StatusSucceeded - registerErr = ErrPaymentAlreadySucceeded } assertPaymentStatus( @@ -926,7 +924,7 @@ func TestPaymentControlMultiShard(t *testing.T) { // Finally assert we cannot register more attempts. _, err = pControl.RegisterAttempt(info.PaymentIdentifier, &b) - require.Equal(t, registerErr, err) + require.ErrorIs(t, err, ErrRegisterAttempt) } for _, test := range tests { @@ -948,7 +946,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { pControl := NewPaymentControl(db) - info, attempt, _, err := genInfo() + info, attempt, _, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Init the payment. @@ -997,7 +995,7 @@ func TestPaymentControlMPPRecordValidation(t *testing.T) { // Create and init a new payment. This time we'll check that we cannot // register an MPP attempt if we already registered a non-MPP one. - info, attempt, _, err = genInfo() + info, attempt, _, err = genInfo(t) require.NoError(t, err, "unable to generate htlc message") err = pControl.InitPayment(info.PaymentIdentifier, info) @@ -1271,7 +1269,7 @@ func createTestPayments(t *testing.T, p *PaymentControl, payments []*payment) { attemptID := uint64(0) for i := 0; i < len(payments); i++ { - info, attempt, preimg, err := genInfo() + info, attempt, preimg, err := genInfo(t) require.NoError(t, err, "unable to generate htlc message") // Set the payment id accordingly in the payments slice. diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 0c3753e662..b2a0292a49 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -64,7 +64,6 @@ var ( TotalAmount: 1234567, SourcePubKey: vertex, Hops: []*route.Hop{ - testHop3, testHop2, testHop1, }, @@ -98,7 +97,7 @@ var ( } ) -func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) { +func makeFakeInfo(t *testing.T) (*PaymentCreationInfo, *HTLCAttemptInfo) { var preimg lntypes.Preimage copy(preimg[:], rev[:]) @@ -113,9 +112,10 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) { PaymentRequest: []byte("test"), } - a := NewHtlcAttempt( + a, err := NewHtlcAttempt( 44, priv, testRoute, time.Unix(100, 0), &hash, ) + require.NoError(t, err) return c, &a.HTLCAttemptInfo } @@ -123,7 +123,7 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) { func TestSentPaymentSerialization(t *testing.T) { t.Parallel() - c, s := makeFakeInfo() + c, s := makeFakeInfo(t) var b bytes.Buffer require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") @@ -174,6 +174,9 @@ func TestSentPaymentSerialization(t *testing.T) { require.NoError(t, err, "deserialize") require.Equal(t, s.Route, newWireInfo.Route) + err = newWireInfo.attachOnionBlobAndCircuit() + require.NoError(t, err) + // Clear routes to allow DeepEqual to compare the remaining fields. newWireInfo.Route = route.Route{} s.Route = route.Route{} @@ -517,7 +520,7 @@ func TestQueryPayments(t *testing.T) { for i := 0; i < nonDuplicatePayments; i++ { // Generate a test payment. - info, _, preimg, err := genInfo() + info, _, preimg, err := genInfo(t) if err != nil { t.Fatalf("unable to create test "+ "payment: %v", err) @@ -618,7 +621,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { pControl := NewPaymentControl(db) // Generate a test payment which does not have duplicates. - noDuplicates, _, _, err := genInfo() + noDuplicates, _, _, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -632,7 +635,7 @@ func TestFetchPaymentWithSequenceNumber(t *testing.T) { require.NoError(t, err) // Generate a test payment which we will add duplicates to. - hasDuplicates, _, preimg, err := genInfo() + hasDuplicates, _, preimg, err := genInfo(t) require.NoError(t, err) // Create a new payment entry in the database. @@ -783,7 +786,7 @@ func putDuplicatePayment(t *testing.T, duplicateBucket kvdb.RwBucket, require.NoError(t, err) // Generate fake information for the duplicate payment. - info, _, _, err := genInfo() + info, _, _, err := genInfo(t) require.NoError(t, err) // Write the payment info to disk under the creation info key. This code diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 4fc64028d3..521d0cfce4 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -63,6 +63,11 @@ * [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9322) that caused estimateroutefee to ignore the default payment timeout. +* [Fixed an edge case](https://github.com/lightningnetwork/lnd/pull/9150) where + the payment may become stuck if the invoice times out while the node + restarts, for details check [this + issue](https://github.com/lightningnetwork/lnd/issues/8975#issuecomment-2270528222). + # New Features * [Support](https://github.com/lightningnetwork/lnd/pull/8390) for diff --git a/itest/lnd_multi-hop_force_close_test.go b/itest/lnd_multi-hop_force_close_test.go index 4284631e86..a1e5d43163 100644 --- a/itest/lnd_multi-hop_force_close_test.go +++ b/itest/lnd_multi-hop_force_close_test.go @@ -357,8 +357,11 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, // We'll create two random payment hashes unknown to carol, then send // each of them by manually specifying the HTLC details. carolPubKey := carol.PubKey[:] - dustPayHash := ht.Random32Bytes() - payHash := ht.Random32Bytes() + + preimageDust := ht.RandomPreimage() + preimage := ht.RandomPreimage() + dustPayHash := preimageDust.Hash() + payHash := preimage.Hash() // If this is a taproot channel, then we'll need to make some manual // route hints so Alice can actually find a route. @@ -370,7 +373,7 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, req := &routerrpc.SendPaymentRequest{ Dest: carolPubKey, Amt: int64(dustHtlcAmt), - PaymentHash: dustPayHash, + PaymentHash: dustPayHash[:], FinalCltvDelta: finalCltvDelta, TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, @@ -381,7 +384,7 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, req = &routerrpc.SendPaymentRequest{ Dest: carolPubKey, Amt: int64(htlcAmt), - PaymentHash: payHash, + PaymentHash: payHash[:], FinalCltvDelta: finalCltvDelta, TimeoutSeconds: 60, FeeLimitMsat: noFeeLimitMsat, @@ -532,6 +535,25 @@ func runLocalClaimOutgoingHTLC(ht *lntest.HarnessTest, // Once this transaction has been confirmed, Bob should detect that he // no longer has any pending channels. ht.AssertNumPendingForceClose(bob, 0) + + // Now that Bob has claimed his HTLCs, Alice should mark the two + // payments as failed. + // + // Alice will mark this payment as failed with no route as the only + // route she has is Alice->Bob->Carol. This won't be the case if she + // has a second route, as another attempt will be tried. + // + // TODO(yy): we should instead mark this payment as timed out if she has + // a second route to try this payment, which is the timeout set by Alice + // when sending the payment. + expectedReason := lnrpc.PaymentFailureReason_FAILURE_REASON_NO_ROUTE + p := ht.AssertPaymentFailureReason(alice, preimage, expectedReason) + require.Equal(ht, lnrpc.Failure_PERMANENT_CHANNEL_FAILURE, + p.Htlcs[0].Failure.Code) + + p = ht.AssertPaymentFailureReason(alice, preimageDust, expectedReason) + require.Equal(ht, lnrpc.Failure_PERMANENT_CHANNEL_FAILURE, + p.Htlcs[0].Failure.Code) } // testMultiHopReceiverPreimageClaimAnchor tests diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index 9499fa25a3..381121d556 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -1344,7 +1344,10 @@ func (s *Server) trackPayment(subscription routing.ControlTowerSubscriber, // Otherwise, we will log and return the error as the stream has // received an error from the payment lifecycle. - log.Errorf("TrackPayment got error for payment %v: %v", identifier, err) + if err != nil { + log.Errorf("TrackPayment got error for payment %v: %v", + identifier, err) + } return err } diff --git a/lntest/harness_assertion.go b/lntest/harness_assertion.go index d5835fe573..e687c8a902 100644 --- a/lntest/harness_assertion.go +++ b/lntest/harness_assertion.go @@ -1600,8 +1600,11 @@ func (h *HarnessTest) AssertPaymentStatus(hn *node.HarnessNode, // AssertPaymentFailureReason asserts that the given node lists a payment with // the given preimage which has the expected failure reason. -func (h *HarnessTest) AssertPaymentFailureReason(hn *node.HarnessNode, - preimage lntypes.Preimage, reason lnrpc.PaymentFailureReason) { +func (h *HarnessTest) AssertPaymentFailureReason( + hn *node.HarnessNode, preimage lntypes.Preimage, + reason lnrpc.PaymentFailureReason) *lnrpc.Payment { + + var payment *lnrpc.Payment payHash := preimage.Hash() err := wait.NoError(func() error { @@ -1610,14 +1613,19 @@ func (h *HarnessTest) AssertPaymentFailureReason(hn *node.HarnessNode, return err } + payment = p + if reason == p.FailureReason { return nil } return fmt.Errorf("payment: %v failure reason not match, "+ - "want %s got %s", payHash, reason, p.Status) + "want %s(%d) got %s(%d)", payHash, reason, reason, + p.FailureReason, p.FailureReason) }, DefaultTimeout) require.NoError(h, err, "timeout checking payment failure reason") + + return payment } // AssertActiveNodesSynced asserts all active nodes have synced to the chain. diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go index 5fb271afa6..532e639d6a 100644 --- a/routing/control_tower_test.go +++ b/routing/control_tower_test.go @@ -535,15 +535,23 @@ func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.HTLCAttemptInfo, } rhash := sha256.Sum256(preimage[:]) + var hash lntypes.Hash + copy(hash[:], rhash[:]) + + attempt, err := channeldb.NewHtlcAttempt( + 1, priv, testRoute, time.Time{}, &hash, + ) + if err != nil { + return nil, nil, lntypes.Preimage{}, err + } + return &channeldb.PaymentCreationInfo{ PaymentIdentifier: rhash, Value: testRoute.ReceiverAmt(), CreationTime: time.Unix(time.Now().Unix(), 0), PaymentRequest: []byte("hola"), }, - &channeldb.NewHtlcAttempt( - 1, priv, testRoute, time.Time{}, nil, - ).HTLCAttemptInfo, preimage, nil + &attempt.HTLCAttemptInfo, preimage, nil } func genPreimage() ([32]byte, error) { diff --git a/routing/payment_lifecycle.go b/routing/payment_lifecycle.go index 180d38a631..37af3b9f21 100644 --- a/routing/payment_lifecycle.go +++ b/routing/payment_lifecycle.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/shards" @@ -40,16 +41,19 @@ type paymentLifecycle struct { quit chan struct{} // resultCollected is used to signal that the result of an attempt has - // been collected. A nil error means the attempt is either successful - // or failed with temporary error. Otherwise, we should exit the - // lifecycle loop as a terminal error has occurred. - resultCollected chan error + // been collected. + resultCollected chan struct{} // resultCollector is a function that is used to collect the result of // an HTLC attempt, which is always mounted to `p.collectResultAsync` // except in unit test, where we use a much simpler resultCollector to // decouple the test flow for the payment lifecycle. resultCollector func(attempt *channeldb.HTLCAttempt) + + // switchResults is a map that holds the results for HTLC attempts + // returned from the htlcswitch. + switchResults lnutils.SyncMap[*channeldb.HTLCAttempt, + *htlcswitch.PaymentResult] } // newPaymentLifecycle initiates a new payment lifecycle and returns it. @@ -66,8 +70,10 @@ func newPaymentLifecycle(r *ChannelRouter, feeLimit lnwire.MilliSatoshi, shardTracker: shardTracker, currentHeight: currentHeight, quit: make(chan struct{}), - resultCollected: make(chan error, 1), + resultCollected: make(chan struct{}, 1), firstHopCustomRecords: firstHopCustomRecords, + switchResults: lnutils.SyncMap[*channeldb.HTLCAttempt, + *htlcswitch.PaymentResult]{}, } // Mount the result collector. @@ -143,12 +149,7 @@ func (p *paymentLifecycle) decideNextStep( // NOTE: we don't check `p.quit` since `decideNextStep` is // running in the same goroutine as `resumePayment`. select { - case err := <-p.resultCollected: - // If an error is returned, exit with it. - if err != nil { - return stepExit, err - } - + case <-p.resultCollected: log.Tracef("Received attempt result for payment %v", p.identifier) @@ -175,20 +176,11 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, // If we had any existing attempts outstanding, we'll start by spinning // up goroutines that'll collect their results and deliver them to the // lifecycle loop below. - payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + payment, err := p.reloadInflightAttempts() if err != nil { return [32]byte{}, nil, err } - for _, a := range payment.InFlightHTLCs() { - a := a - - log.Infof("Resuming HTLC attempt %v for payment %v", - a.AttemptID, p.identifier) - - p.resultCollector(&a) - } - // Get the payment status. status := payment.GetStatus() @@ -211,23 +203,14 @@ func (p *paymentLifecycle) resumePayment(ctx context.Context) ([32]byte, // critical error during path finding. lifecycle: for { - // We update the payment state on every iteration. Since the - // payment state is affected by multiple goroutines (ie, - // collectResultAsync), it is NOT guaranteed that we always - // have the latest state here. This is fine as long as the - // state is consistent as a whole. - payment, err = p.router.cfg.Control.FetchPayment(p.identifier) + // We update the payment state on every iteration. + currentPayment, ps, err := p.processResultsAndReloadPayment() if err != nil { return exitWithErr(err) } - ps := payment.GetState() - remainingFees := p.calcFeeBudget(ps.FeesPaid) - - status = payment.GetStatus() - log.Debugf("Payment %v: status=%v, active_shards=%v, "+ - "rem_value=%v, fee_limit=%v", p.identifier, status, - ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) + status = currentPayment.GetStatus() + payment = currentPayment // We now proceed our lifecycle with the following tasks in // order, @@ -288,16 +271,19 @@ lifecycle: log.Tracef("Found route: %s", spew.Sdump(rt.Hops)) - // Allow the traffic shaper to add custom records to the - // outgoing HTLC and also adjust the amount if needed. - err = p.amendFirstHopData(rt) - if err != nil { - return exitWithErr(err) - } - // We found a route to try, create a new HTLC attempt to try. attempt, err := p.registerAttempt(rt, ps.RemainingAmt) if err != nil { + // If the error is due to we cannot register another + // HTLC, we will skip this iteration and continue to + // the next one in case there are inflight HTLCs. + // + // TODO(yy): remove this check once we have a finer + // control over errors returned from the switch. + if errors.Is(err, channeldb.ErrRegisterAttempt) { + continue lifecycle + } + return exitWithErr(err) } @@ -391,6 +377,13 @@ func (p *paymentLifecycle) requestRoute( // Exit early if there's no error. if err == nil { + // Allow the traffic shaper to add custom records to the + // outgoing HTLC and also adjust the amount if needed. + err = p.amendFirstHopData(rt) + if err != nil { + return nil, err + } + return rt, nil } @@ -444,62 +437,55 @@ type attemptResult struct { } // collectResultAsync launches a goroutine that will wait for the result of the -// given HTLC attempt to be available then handle its result. Once received, it -// will send a nil error to channel `resultCollected` to indicate there's a -// result. +// given HTLC attempt to be available then save its result in a map. Once +// received, it will send a signal to channel `resultCollected` to indicate +// there's a result. func (p *paymentLifecycle) collectResultAsync(attempt *channeldb.HTLCAttempt) { log.Debugf("Collecting result for attempt %v in payment %v", attempt.AttemptID, p.identifier) go func() { - // Block until the result is available. - _, err := p.collectResult(attempt) + result, err := p.collectResult(attempt) if err != nil { - log.Errorf("Error collecting result for attempt %v "+ - "in payment %v: %v", attempt.AttemptID, + log.Errorf("Error collecting result for attempt %v in "+ + "payment %v: %v", attempt.AttemptID, p.identifier, err) + + return } log.Debugf("Result collected for attempt %v in payment %v", attempt.AttemptID, p.identifier) - // Once the result is collected, we signal it by writing the - // error to `resultCollected`. + // Save the result and process it in the next main loop. + p.switchResults.Store(attempt, result) + + // Signal that a result has been collected. select { // Send the signal or quit. - case p.resultCollected <- err: + case p.resultCollected <- struct{}{}: case <-p.quit: log.Debugf("Lifecycle exiting while collecting "+ "result for payment %v", p.identifier) case <-p.router.quit: - return } }() } -// collectResult waits for the result for the given attempt to be available -// from the Switch, then records the attempt outcome with the control tower. -// An attemptResult is returned, indicating the final outcome of this HTLC -// attempt. -func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( - *attemptResult, error) { +// collectResult waits for the result of the given HTLC attempt to be sent by +// the switch and returns it. +func (p *paymentLifecycle) collectResult( + attempt *channeldb.HTLCAttempt) (*htlcswitch.PaymentResult, error) { log.Tracef("Collecting result for attempt %v", spew.Sdump(attempt)) - // We'll retrieve the hash specific to this shard from the - // shardTracker, since it will be needed to regenerate the circuit - // below. - hash, err := p.shardTracker.GetHash(attempt.AttemptID) - if err != nil { - return p.failAttempt(attempt.AttemptID, err) - } + result := &htlcswitch.PaymentResult{} // Regenerate the circuit for this attempt. - _, circuit, err := generateSphinxPacket( - &attempt.Route, hash[:], attempt.SessionKey(), - ) + circuit, err := attempt.Circuit() + // TODO(yy): We generate this circuit to create the error decryptor, // which is then used in htlcswitch as the deobfuscator to decode the // error from `UpdateFailHTLC`. However, suppose it's an @@ -512,8 +498,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( if err != nil { log.Debugf("Unable to generate circuit for attempt %v: %v", attempt.AttemptID, err) - - return p.failAttempt(attempt.AttemptID, err) + return nil, err } // Using the created circuit, initialize the error decrypter, so we can @@ -539,22 +524,21 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( log.Errorf("Failed getting result for attemptID %d "+ "from switch: %v", attempt.AttemptID, err) - return p.handleSwitchErr(attempt, err) + result.Error = err + + return result, nil } // The switch knows about this payment, we'll wait for a result to be // available. - var ( - result *htlcswitch.PaymentResult - ok bool - ) - select { - case result, ok = <-resultChan: + case r, ok := <-resultChan: if !ok { return nil, htlcswitch.ErrSwitchExiting } + result = r + case <-p.quit: return nil, ErrPaymentLifecycleExiting @@ -562,46 +546,7 @@ func (p *paymentLifecycle) collectResult(attempt *channeldb.HTLCAttempt) ( return nil, ErrRouterShuttingDown } - // In case of a payment failure, fail the attempt with the control - // tower and return. - if result.Error != nil { - return p.handleSwitchErr(attempt, result.Error) - } - - // We successfully got a payment result back from the switch. - log.Debugf("Payment %v succeeded with pid=%v", - p.identifier, attempt.AttemptID) - - // Report success to mission control. - err = p.router.cfg.MissionControl.ReportPaymentSuccess( - attempt.AttemptID, &attempt.Route, - ) - if err != nil { - log.Errorf("Error reporting payment success to mc: %v", err) - } - - // In case of success we atomically store settle result to the DB move - // the shard to the settled state. - htlcAttempt, err := p.router.cfg.Control.SettleAttempt( - p.identifier, attempt.AttemptID, - &channeldb.HTLCSettleInfo{ - Preimage: result.Preimage, - SettleTime: p.router.cfg.Clock.Now(), - }, - ) - if err != nil { - log.Errorf("Error settling attempt %v for payment %v with "+ - "preimage %v: %v", attempt.AttemptID, p.identifier, - result.Preimage, err) - - // We won't mark the attempt as failed since we already have - // the preimage. - return nil, err - } - - return &attemptResult{ - attempt: htlcAttempt, - }, nil + return result, nil } // registerAttempt is responsible for creating and saving an HTLC attempt in db @@ -675,11 +620,9 @@ func (p *paymentLifecycle) createNewPaymentAttempt(rt *route.Route, // We now have all the information needed to populate the current // attempt information. - attempt := channeldb.NewHtlcAttempt( + return channeldb.NewHtlcAttempt( attemptID, sessionKey, *rt, p.router.cfg.Clock.Now(), &hash, ) - - return attempt, nil } // sendAttempt attempts to send the current attempt to the switch to complete @@ -711,9 +654,7 @@ func (p *paymentLifecycle) sendAttempt( // Generate the raw encoded sphinx packet to be included along // with the htlcAdd message that we send directly to the // switch. - onionBlob, _, err := generateSphinxPacket( - &rt, attempt.Hash[:], attempt.SessionKey(), - ) + onionBlob, err := attempt.OnionBlob() if err != nil { log.Errorf("Failed to create onion blob: attempt=%d in "+ "payment=%v, err:%v", attempt.AttemptID, @@ -722,7 +663,7 @@ func (p *paymentLifecycle) sendAttempt( return p.failAttempt(attempt.AttemptID, err) } - copy(htlcAdd.OnionBlob[:], onionBlob) + htlcAdd.OnionBlob = onionBlob // Send it to the Switch. When this method returns we assume // the Switch successfully has persisted the payment attempt, @@ -885,8 +826,8 @@ func (p *paymentLifecycle) handleSwitchErr(attempt *channeldb.HTLCAttempt, // case we can safely send a new payment attempt, and wait for its // result to be available. if errors.Is(sendErr, htlcswitch.ErrPaymentIDNotFound) { - log.Debugf("Attempt ID %v for payment %v not found in the "+ - "Switch, retrying.", attempt.AttemptID, p.identifier) + log.Warnf("Failing attempt=%v for payment=%v as it's not "+ + "found in the Switch", attempt.AttemptID, p.identifier) return p.failAttempt(attemptID, sendErr) } @@ -1097,3 +1038,151 @@ func marshallError(sendError error, time time.Time) *channeldb.HTLCFailInfo { return response } + +// reloadInflightAttempts is called when the payment lifecycle is resumed after +// a restart. It reloads all inflight attempts from the control tower and +// collects the results of the attempts that have been sent before. +func (p *paymentLifecycle) reloadInflightAttempts() (DBMPPayment, error) { + payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + if err != nil { + return nil, err + } + + for _, a := range payment.InFlightHTLCs() { + a := a + + log.Infof("Resuming HTLC attempt %v for payment %v", + a.AttemptID, p.identifier) + + p.resultCollector(&a) + } + + return payment, nil +} + +// processResultsAndReloadPayment returns the latest payment found in the db +// (control tower) after all its attempt results are processed. +func (p *paymentLifecycle) processResultsAndReloadPayment() (DBMPPayment, + *channeldb.MPPaymentState, error) { + + // Process the stored results first as they will affect the state of + // the payment. + if err := p.processSwitchResults(); err != nil { + return nil, nil, err + } + + // Read the db to get the latest state of the payment. + payment, err := p.router.cfg.Control.FetchPayment(p.identifier) + if err != nil { + return nil, nil, err + } + + ps := payment.GetState() + remainingFees := p.calcFeeBudget(ps.FeesPaid) + + log.Debugf("Payment %v: status=%v, active_shards=%v, rem_value=%v, "+ + "fee_limit=%v", p.identifier, payment.GetStatus(), + ps.NumAttemptsInFlight, ps.RemainingAmt, remainingFees) + + return payment, ps, nil +} + +// processSwitchResults reads the `p.results` map and process the results +// returned from the htlcswitch. +func (p *paymentLifecycle) processSwitchResults() error { + // Create a slice to remember the results of the attempts that we have + // processed. + attempts := make([]*channeldb.HTLCAttempt, 0, p.switchResults.Len()) + + var errReturned error + + // Range over the map to process all the results. + p.switchResults.Range(func(a *channeldb.HTLCAttempt, + result *htlcswitch.PaymentResult) bool { + + // Save the keys so we know which items to delete from the map. + attempts = append(attempts, a) + + // Handle the attempt result. If an error is returned here, it + // means the payment lifecycle needs to be terminated. + _, err := p.handleAttemptResult(a, result) + if err != nil { + log.Errorf("Error handling result for attempt=%v in "+ + "payment%v: %v", a.AttemptID, p.identifier, err) + + errReturned = err + } + + // Always return true so we will process all results. + return true + }) + + // Clean up the processed results. + for _, a := range attempts { + p.switchResults.Delete(a) + } + + return errReturned +} + +// handleAttemptResult processes the result of an HTLC attempt returned from +// the htlcswitch. +func (p *paymentLifecycle) handleAttemptResult(attempt *channeldb.HTLCAttempt, + result *htlcswitch.PaymentResult) (*attemptResult, error) { + + // If the result has an error, we need to further process it by failing + // the attempt and maybe fail the payment. + if result.Error != nil { + return p.handleSwitchErr(attempt, result.Error) + } + + // We got an attempt settled result back from the switch. + log.Debugf("Payment(%v): attempt(%v) succeeded", p.identifier, + attempt.AttemptID) + + // Report success to mission control. + err := p.router.cfg.MissionControl.ReportPaymentSuccess( + attempt.AttemptID, &attempt.Route, + ) + if err != nil { + log.Errorf("Error reporting payment success to mc: %v", err) + } + + // In case of success we atomically store settle result to the DB move + // the shard to the settled state. + htlcAttempt, err := p.router.cfg.Control.SettleAttempt( + p.identifier, attempt.AttemptID, + &channeldb.HTLCSettleInfo{ + Preimage: result.Preimage, + SettleTime: p.router.cfg.Clock.Now(), + }, + ) + if err != nil { + log.Errorf("Error settling attempt %v for payment %v with "+ + "preimage %v: %v", attempt.AttemptID, p.identifier, + result.Preimage, err) + + // We won't mark the attempt as failed since we already have + // the preimage. + return nil, err + } + + return &attemptResult{ + attempt: htlcAttempt, + }, nil +} + +// collectAndHandleResult waits for the result for the given attempt to be +// available from the Switch, then records the attempt outcome with the control +// tower. An attemptResult is returned, indicating the final outcome of this +// HTLC attempt. +func (p *paymentLifecycle) collectAndHandleResult( + attempt *channeldb.HTLCAttempt) (*attemptResult, error) { + + result, err := p.collectResult(attempt) + if err != nil { + return nil, err + } + + return p.handleAttemptResult(attempt, result) +} diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 72aa631419..6052761096 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -260,10 +260,15 @@ func createDummyRoute(t *testing.T, amt lnwire.MilliSatoshi) *route.Route { func makeSettledAttempt(t *testing.T, total int, preimage lntypes.Preimage) *channeldb.HTLCAttempt { - return &channeldb.HTLCAttempt{ + a := &channeldb.HTLCAttempt{ HTLCAttemptInfo: makeAttemptInfo(t, total), Settle: &channeldb.HTLCSettleInfo{Preimage: preimage}, } + + hash := preimage.Hash() + a.Hash = &hash + + return a } func makeFailedAttempt(t *testing.T, total int) *channeldb.HTLCAttempt { @@ -279,6 +284,7 @@ func makeAttemptInfo(t *testing.T, amt int) channeldb.HTLCAttemptInfo { rt := createDummyRoute(t, lnwire.MilliSatoshi(amt)) return channeldb.HTLCAttemptInfo{ Route: *rt, + Hash: &lntypes.Hash{1, 2, 3}, } } @@ -611,8 +617,8 @@ func TestDecideNextStep(t *testing.T) { // Send a nil error to the attemptResultChan if requested. if tc.closeResultChan { - p.resultCollected = make(chan error, 1) - p.resultCollected <- nil + p.resultCollected = make(chan struct{}, 1) + p.resultCollected <- struct{}{} } // Quit the router if requested. @@ -1303,11 +1309,6 @@ func TestCollectResultExitOnErr(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a dummy error. m.payer.On("GetAttemptResult", attempt.AttemptID, p.identifier, mock.Anything, @@ -1332,7 +1333,7 @@ func TestCollectResultExitOnErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } @@ -1348,11 +1349,6 @@ func TestCollectResultExitOnResultErr(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1383,7 +1379,7 @@ func TestCollectResultExitOnResultErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, errDummy, "expected dummy error") require.Nil(t, result, "expected nil attempt") } @@ -1399,11 +1395,6 @@ func TestCollectResultExitOnSwitchQuit(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1414,7 +1405,7 @@ func TestCollectResultExitOnSwitchQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, htlcswitch.ErrSwitchExiting, "expected switch exit") require.Nil(t, result, "expected nil attempt") @@ -1431,11 +1422,6 @@ func TestCollectResultExitOnRouterQuit(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1446,7 +1432,7 @@ func TestCollectResultExitOnRouterQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, ErrRouterShuttingDown, "expected router exit") require.Nil(t, result, "expected nil attempt") } @@ -1462,11 +1448,6 @@ func TestCollectResultExitOnLifecycleQuit(t *testing.T) { paymentAmt := 10_000 attempt := makeFailedAttempt(t, paymentAmt) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1477,7 +1458,7 @@ func TestCollectResultExitOnLifecycleQuit(t *testing.T) { }) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, ErrPaymentLifecycleExiting, "expected lifecycle exit") require.Nil(t, result, "expected nil attempt") @@ -1495,11 +1476,6 @@ func TestCollectResultExitOnSettleErr(t *testing.T) { preimage := lntypes.Preimage{1} attempt := makeSettledAttempt(t, paymentAmt, preimage) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1526,7 +1502,7 @@ func TestCollectResultExitOnSettleErr(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.ErrorIs(t, err, errDummy, "expected settle error") require.Nil(t, result, "expected nil attempt") } @@ -1542,11 +1518,6 @@ func TestCollectResultSuccess(t *testing.T) { preimage := lntypes.Preimage{1} attempt := makeSettledAttempt(t, paymentAmt, preimage) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1573,7 +1544,7 @@ func TestCollectResultSuccess(t *testing.T) { m.clock.On("Now").Return(time.Now()) // Now call the method under test. - result, err := p.collectResult(attempt) + result, err := p.collectAndHandleResult(attempt) require.NoError(t, err, "expected no error") require.Equal(t, preimage, result.attempt.Settle.Preimage, "preimage mismatch") @@ -1590,11 +1561,6 @@ func TestCollectResultAsyncSuccess(t *testing.T) { preimage := lntypes.Preimage{1} attempt := makeSettledAttempt(t, paymentAmt, preimage) - // Mock shardTracker to return the payment hash. - m.shardTracker.On("GetHash", - attempt.AttemptID, - ).Return(p.identifier, nil).Once() - // Mock the htlcswitch to return a the result chan. resultChan := make(chan *htlcswitch.PaymentResult, 1) m.payer.On("GetAttemptResult", @@ -1606,8 +1572,80 @@ func TestCollectResultAsyncSuccess(t *testing.T) { } }) - // Once the result is received, `ReportPaymentSuccess` should be - // called. + // Now call the method under test. + p.collectResultAsync(attempt) + + // Assert the result is returned within 5 seconds. + waitErr := wait.NoError(func() error { + <-p.resultCollected + return nil + }, testTimeout) + require.NoError(t, waitErr, "timeout waiting for result") + + // Assert the result is saved in the map. + p.switchResults.Load(attempt) +} + +// TestHandleAttemptResultWithError checks that when the `Error` field in the +// result is not nil, it's properly handled by `handleAttemptResult`. +func TestHandleAttemptResultWithError(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Create a result that contains an error. + // + // NOTE: The error is chosen so we can quickly exit `handleSwitchErr` + // since we are not testing its behavior here. + result := &htlcswitch.PaymentResult{ + Error: htlcswitch.ErrPaymentIDNotFound, + } + + // The above error will end up being handled by `handleSwitchErr`, in + // which we'd cancel the shard and fail the attempt. + // + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a dummy error. + m.control.On("FailAttempt", + p.identifier, attempt.AttemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Call the method under test and expect the dummy error to be + // returned. + attemptResult, err := p.handleAttemptResult(attempt, result) + require.ErrorIs(t, err, errDummy, "expected fail error") + require.Nil(t, attemptResult, "expected nil attempt result") +} + +// TestHandleAttemptResultSuccess checks that when the result contains no error +// but a preimage, it's handled correctly by `handleAttemptResult`. +func TestHandleAttemptResultSuccess(t *testing.T) { + t.Parallel() + + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt := makeSettledAttempt(t, paymentAmt, preimage) + + // Create a result that contains a preimage. + result := &htlcswitch.PaymentResult{ + Preimage: preimage, + } + + // Since the result doesn't contain an error, `ReportPaymentSuccess` + // should be called. m.missionControl.On("ReportPaymentSuccess", attempt.AttemptID, &attempt.Route, ).Return(nil).Once() @@ -1620,17 +1658,74 @@ func TestCollectResultAsyncSuccess(t *testing.T) { // Mock the clock to return a current time. m.clock.On("Now").Return(time.Now()) - // Now call the method under test. - p.collectResultAsync(attempt) + // Call the method under test and expect the dummy error to be + // returned. + attemptResult, err := p.handleAttemptResult(attempt, result) + require.NoError(t, err, "expected no error") + require.Equal(t, attempt, attemptResult.attempt) +} - // Assert the result is returned within 5 seconds. - var err error - waitErr := wait.NoError(func() error { - err = <-p.resultCollected - return nil - }, testTimeout) - require.NoError(t, waitErr, "timeout waiting for result") +// TestProcessSwitchResults checks that `processSwitchResults` will update the +// map `switchResults` as expected and handle the results correctly. +func TestProcessSwitchResults(t *testing.T) { + t.Parallel() - // Assert that a nil error is received. - require.NoError(t, err, "expected no error") + // Create a test paymentLifecycle with the initial two calls mocked. + p, m := newTestPaymentLifecycle(t) + + paymentAmt := 10_000 + preimage := lntypes.Preimage{1} + attempt1 := makeSettledAttempt(t, paymentAmt, preimage) + attempt2 := makeSettledAttempt(t, paymentAmt, preimage) + + // Create a result that contains an error. + // + // NOTE: The error is chosen so we can quickly exit `handleSwitchErr` + // since we are not testing its behavior here. + result1 := &htlcswitch.PaymentResult{ + Error: htlcswitch.ErrPaymentIDNotFound, + } + + // Create a result that contains a preimage. + result2 := &htlcswitch.PaymentResult{ + Preimage: preimage, + } + + // Save the results to the map. + p.switchResults.Store(attempt1, result1) + p.switchResults.Store(attempt2, result2) + + // Since result1 contains an error, it will end up being handled by + // `handleAttemptResult`, in which we'd cancel the shard and fail the + // attempt. + // + // `CancelShard` should be called with the attemptID. + m.shardTracker.On("CancelShard", attempt1.AttemptID).Return(nil).Once() + + // Mock `FailAttempt` to return a dummy error. + m.control.On("FailAttempt", + p.identifier, attempt1.AttemptID, mock.Anything, + ).Return(nil, errDummy).Once() + + // Mock the clock to return a current time. + m.clock.On("Now").Return(time.Now()) + + // Since result2 doesn't contain an error, `ReportPaymentSuccess` + // should be called. + m.missionControl.On("ReportPaymentSuccess", + attempt2.AttemptID, &attempt2.Route, + ).Return(nil).Once() + + // Now the settled htlc being returned from `SettleAttempt`. + m.control.On("SettleAttempt", + p.identifier, attempt2.AttemptID, mock.Anything, + ).Return(attempt2, nil).Once() + + // Call the method under test and expect the dummy error to be + // returned from processing result1. + err := p.processSwitchResults() + require.ErrorIs(t, err, errDummy, "expected fail error") + + // Assert the map is cleaned. + require.Zero(t, p.switchResults.Len(), "expected no results in map") } diff --git a/routing/router.go b/routing/router.go index 468510a6c7..323fe76167 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1,7 +1,6 @@ package routing import ( - "bytes" "context" "fmt" "math" @@ -15,7 +14,6 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" - sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/amp" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" @@ -722,71 +720,6 @@ func generateNewSessionKey() (*btcec.PrivateKey, error) { return btcec.NewPrivateKey() } -// generateSphinxPacket generates then encodes a sphinx packet which encodes -// the onion route specified by the passed layer 3 route. The blob returned -// from this function can immediately be included within an HTLC add packet to -// be sent to the first hop within the route. -func generateSphinxPacket(rt *route.Route, paymentHash []byte, - sessionKey *btcec.PrivateKey) ([]byte, *sphinx.Circuit, error) { - - // Now that we know we have an actual route, we'll map the route into a - // sphinx payment path which includes per-hop payloads for each hop - // that give each node within the route the necessary information - // (fees, CLTV value, etc.) to properly forward the payment. - sphinxPath, err := rt.ToSphinxPath() - if err != nil { - return nil, nil, err - } - - log.Tracef("Constructed per-hop payloads for payment_hash=%x: %v", - paymentHash, lnutils.NewLogClosure(func() string { - path := make( - []sphinx.OnionHop, sphinxPath.TrueRouteLength(), - ) - for i := range path { - hopCopy := sphinxPath[i] - path[i] = hopCopy - } - - return spew.Sdump(path) - }), - ) - - // Next generate the onion routing packet which allows us to perform - // privacy preserving source routing across the network. - sphinxPacket, err := sphinx.NewOnionPacket( - sphinxPath, sessionKey, paymentHash, - sphinx.DeterministicPacketFiller, - ) - if err != nil { - return nil, nil, err - } - - // Finally, encode Sphinx packet using its wire representation to be - // included within the HTLC add packet. - var onionBlob bytes.Buffer - if err := sphinxPacket.Encode(&onionBlob); err != nil { - return nil, nil, err - } - - log.Tracef("Generated sphinx packet: %v", - lnutils.NewLogClosure(func() string { - // We make a copy of the ephemeral key and unset the - // internal curve here in order to keep the logs from - // getting noisy. - key := *sphinxPacket.EphemeralKey - packetCopy := *sphinxPacket - packetCopy.EphemeralKey = &key - return spew.Sdump(packetCopy) - }), - ) - - return onionBlob.Bytes(), &sphinx.Circuit{ - SessionKey: sessionKey, - PaymentPath: sphinxPath.NodeKeys(), - }, nil -} - // LightningPayment describes a payment to be sent through the network to the // final destination. type LightningPayment struct { @@ -1253,7 +1186,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // The attempt was successfully sent, wait for the result to be // available. - result, err = p.collectResult(attempt) + result, err = p.collectAndHandleResult(attempt) if err != nil { return nil, err } diff --git a/routing/router_test.go b/routing/router_test.go index 22c9d14e50..543f3b0005 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -48,13 +48,6 @@ var ( testFeatures = lnwire.NewFeatureVector(nil, lnwire.Features) - testHash = [32]byte{ - 0xb7, 0x94, 0x38, 0x5f, 0x2d, 0x1e, 0xf7, 0xab, - 0x4d, 0x92, 0x73, 0xd1, 0x90, 0x63, 0x81, 0xb4, - 0x4f, 0x2f, 0x6f, 0x25, 0x88, 0xa3, 0xef, 0xb9, - 0x6a, 0x49, 0x18, 0x83, 0x31, 0x98, 0x47, 0x53, - } - testTime = time.Date(2018, time.January, 9, 14, 00, 00, 0, time.UTC) priv1, _ = btcec.NewPrivateKey() @@ -1235,18 +1228,6 @@ func TestFindPathFeeWeighting(t *testing.T) { require.Equal(t, ctx.aliases["luoji"], path[0].policy.ToNodePubKey()) } -// TestEmptyRoutesGenerateSphinxPacket tests that the generateSphinxPacket -// function is able to gracefully handle being passed a nil set of hops for the -// route by the caller. -func TestEmptyRoutesGenerateSphinxPacket(t *testing.T) { - t.Parallel() - - sessionKey, _ := btcec.NewPrivateKey() - emptyRoute := &route.Route{} - _, _, err := generateSphinxPacket(emptyRoute, testHash[:], sessionKey) - require.ErrorIs(t, err, route.ErrNoRouteHopsProvided) -} - // TestUnknownErrorSource tests that if the source of an error is unknown, all // edges along the route will be pruned. func TestUnknownErrorSource(t *testing.T) {