diff --git a/rfq/order.go b/rfq/order.go index 035bd91fa..135977a86 100644 --- a/rfq/order.go +++ b/rfq/order.go @@ -71,6 +71,14 @@ type Policy interface { // which the policy applies. Scid() uint64 + // TrackAcceptedHtlc makes the policy aware of this new accepted HTLC. + // This is important in cases where the set of existing HTLCs may affect + // whether the next compliance check passes. + TrackAcceptedHtlc(htlc lndclient.InterceptedHtlc) + + // UntrackHtlc stops tracking the uniquely identified htlc. + UntrackHtlc(htlcID string) + // GenerateInterceptorResponse generates an interceptor response for the // HTLC interceptor from the policy. GenerateInterceptorResponse( @@ -95,9 +103,17 @@ type AssetSalePolicy struct { // the policy. MaxOutboundAssetAmount uint64 + // CurrentAssetAmountMsat is the total amount that is held currently in + // accepted htlcs. + CurrentAmountMsat lnwire.MilliSatoshi + // AskAssetRate is the quote's asking asset unit to BTC conversion rate. AskAssetRate rfqmath.BigIntFixedPoint + // htlcToAmt maps the unique htlc identifiers to the effective amount + // that they carry. + htlcToAmt lnutils.SyncMap[string, lnwire.MilliSatoshi] + // expiry is the policy's expiry unix timestamp after which the policy // is no longer valid. expiry uint64 @@ -152,7 +168,8 @@ func (c *AssetSalePolicy) CheckHtlcCompliance( maxAssetAmount, c.AskAssetRate, ) - if htlc.AmountOutMsat > policyMaxOutMsat { + if (c.CurrentAmountMsat + htlc.AmountOutMsat) > policyMaxOutMsat { + // if htlc.AmountOutMsat > policyMaxOutMsat { return fmt.Errorf("htlc out amount is greater than the policy "+ "maximum (htlc_out_msat=%d, policy_max_out_msat=%d)", htlc.AmountOutMsat, policyMaxOutMsat) @@ -167,6 +184,29 @@ func (c *AssetSalePolicy) CheckHtlcCompliance( return nil } +// TrackAcceptedHtlc accounts for the newly accepted htlc. This may affect the +// acceptance of future htlcs. +func (c *AssetSalePolicy) TrackAcceptedHtlc(htlc lndclient.InterceptedHtlc) { + c.CurrentAmountMsat += htlc.AmountOutMsat + + htlcIDStr := htlcIdentifierStr( + htlc.IncomingCircuitKey.ChanID.ToUint64(), + htlc.IncomingCircuitKey.HtlcID, + ) + + c.htlcToAmt.Store(htlcIDStr, htlc.AmountOutMsat) +} + +// UntrackHtlc stops tracking the uniquely identified htlc. +func (c *AssetSalePolicy) UntrackHtlc(htlcIDStr string) { + amt, found := c.htlcToAmt.LoadAndDelete(htlcIDStr) + if !found { + return + } + + c.CurrentAmountMsat -= amt +} + // Expiry returns the policy's expiry time as a unix timestamp. func (c *AssetSalePolicy) Expiry() uint64 { return c.expiry @@ -246,12 +286,20 @@ type AssetPurchasePolicy struct { // AcceptedQuoteId is the ID of the accepted quote. AcceptedQuoteId rfqmsg.ID + // CurrentAssetAmountMsat is the total amount that is held currently in + // accepted htlcs. + CurrentAmountMsat lnwire.MilliSatoshi + // BidAssetRate is the quote's asset to BTC conversion rate. BidAssetRate rfqmath.BigIntFixedPoint // PaymentMaxAmt is the maximum agreed BTC payment. PaymentMaxAmt lnwire.MilliSatoshi + // htlcToAmt maps the unique htlc identifiers to the effective amount + // that they carry. + htlcToAmt lnutils.SyncMap[string, lnwire.MilliSatoshi] + // expiry is the policy's expiry unix timestamp in seconds after which // the policy is no longer valid. expiry uint64 @@ -322,7 +370,7 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance( // Ensure that the outbound HTLC amount is less than the maximum agreed // BTC payment. - if htlc.AmountOutMsat > c.PaymentMaxAmt { + if (c.CurrentAmountMsat + htlc.AmountOutMsat) > c.PaymentMaxAmt { return fmt.Errorf("htlc out amount is more than the maximum "+ "agreed BTC payment (htlc_out_msat=%d, "+ "payment_max_amt=%d)", htlc.AmountOutMsat, @@ -338,6 +386,31 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance( return nil } +// TrackAcceptedHtlc accounts for the newly accepted htlc. This may affect the +// acceptance of future htlcs. +func (c *AssetPurchasePolicy) TrackAcceptedHtlc( + htlc lndclient.InterceptedHtlc) { + + c.CurrentAmountMsat += htlc.AmountOutMsat + + htlcIDStr := htlcIdentifierStr( + htlc.IncomingCircuitKey.ChanID.ToUint64(), + htlc.IncomingCircuitKey.HtlcID, + ) + + c.htlcToAmt.Store(htlcIDStr, htlc.AmountOutMsat) +} + +// UntrackHtlc stops tracking the uniquely identified htlc. +func (c *AssetPurchasePolicy) UntrackHtlc(htlcIDStr string) { + amt, found := c.htlcToAmt.LoadAndDelete(htlcIDStr) + if !found { + return + } + + c.CurrentAmountMsat -= amt +} + // Expiry returns the policy's expiry time as a unix timestamp in seconds. func (c *AssetPurchasePolicy) Expiry() uint64 { return c.expiry @@ -436,6 +509,25 @@ func (a *AssetForwardPolicy) CheckHtlcCompliance( return nil } +// TrackAcceptedHtlc accounts for the newly accepted htlc. This may affect the +// acceptance of future htlcs. +func (a *AssetForwardPolicy) TrackAcceptedHtlc(htlc lndclient.InterceptedHtlc) { + // Track accepted htlc in the incoming policy. + a.incomingPolicy.TrackAcceptedHtlc(htlc) + + // Track accepted htlc in the outgoing policy. + a.outgoingPolicy.TrackAcceptedHtlc(htlc) +} + +// UntrackHtlc stops tracking the uniquely identified htlc. +func (a *AssetForwardPolicy) UntrackHtlc(htlcIDStr string) { + // Untrack htlc in the incoming policy. + a.incomingPolicy.UntrackHtlc(htlcIDStr) + + // Untrack htlc in the outgoing policy. + a.outgoingPolicy.UntrackHtlc(htlcIDStr) +} + // Expiry returns the policy's expiry time as a unix timestamp in seconds. The // returned expiry time is the earliest expiry time of the incoming and outgoing // policies. @@ -514,6 +606,10 @@ type OrderHandlerCfg struct { // AcceptHtlcEvents is a channel that receives accepted HTLCs. AcceptHtlcEvents chan<- *AcceptHtlcEvent + + // HtlcSubscriber is a subscriber that is used to retrieve live HTLC + // event updates. + HtlcSubscriber HtlcSubscriber } // OrderHandler orchestrates management of accepted quote bundles. It monitors @@ -530,6 +626,12 @@ type OrderHandler struct { // associated asset transaction policies. policies lnutils.SyncMap[SerialisedScid, Policy] + // htlcToPolicy maps a unique htlc identifier encoded as a string, to + // the policy that applies to it. We need this map because for failed + // HTLCs we don't have the RFQ data available, so we need to cache this + // info. + htlcToPolicy lnutils.SyncMap[string, Policy] + // ContextGuard provides a wait group and main quit channel that can be // used to create guarded contexts. *fn.ContextGuard @@ -593,6 +695,17 @@ func (h *OrderHandler) handleIncomingHtlc(_ context.Context, }, nil } + htlcIDStr := htlcIdentifierStr( + htlc.IncomingCircuitKey.ChanID.ToUint64(), + htlc.IncomingCircuitKey.HtlcID, + ) + + h.htlcToPolicy.Store(htlcIDStr, policy) + + // The htlc passed the compliance checks, so now we keep track of the + // accepted htlc. + policy.TrackAcceptedHtlc(htlc) + log.Debug("HTLC complies with policy. Broadcasting accept event.") h.cfg.AcceptHtlcEvents <- NewAcceptHtlcEvent(htlc, policy) @@ -640,12 +753,64 @@ func (h *OrderHandler) mainEventLoop() { } } +// subscribeHtlcs subscribes the OrderHandler to HTLC events provided by the lnd +// RPC interface. We use this subscription to track HTLC forwarding failures, +// which we use to performn a live update of our policies. +func (h *OrderHandler) subscribeHtlcs(ctx context.Context) error { + events, chErr, err := h.cfg.HtlcSubscriber.SubscribeHtlcEvents(ctx) + if err != nil { + return err + } + + for { + select { + case event := <-events: + // We only care about forwarding events. + if event.GetEventType() != routerrpc.HtlcEvent_FORWARD { + continue + } + + // Retrieve the two instances that may be relevant. + failEvent := event.GetForwardFailEvent() + linkFail := event.GetLinkFailEvent() + + // Craft the string representation of the unique htlc + // identifier. This is later on used to map to an rfq + // policy. + htlcIDStr := htlcIdentifierStr( + event.IncomingChannelId, event.IncomingHtlcId, + ) + + switch { + case failEvent != nil: + fallthrough + case linkFail != nil: + // Fetch the policy that is related to this + // htlc. + policy, found := + h.htlcToPolicy.LoadAndDelete(htlcIDStr) + + if !found { + continue + } + + // Stop tracking this htlc as it failed. + policy.UntrackHtlc(htlcIDStr) + } + + case err := <-chErr: + return err + + case <-ctx.Done(): + return ctx.Err() + } + } +} + // Start starts the service. func (h *OrderHandler) Start() error { var startErr error h.startOnce.Do(func() { - log.Info("Starting subsystem: order handler") - // Start the main event loop in a separate goroutine. h.Wg.Add(1) go func() { @@ -663,6 +828,20 @@ func (h *OrderHandler) Start() error { h.mainEventLoop() }() + + // Start the HTLC event subscription loop. + h.Wg.Add(1) + go func() { + defer h.Wg.Done() + + ctx, cancel := h.WithCtxQuitNoTimeout() + defer cancel() + + err := h.subscribeHtlcs(ctx) + if err != nil { + log.Errorf("htlc subscriber error: %v", err) + } + }() }) return startErr @@ -851,3 +1030,9 @@ type HtlcSubscriber interface { SubscribeHtlcEvents(ctx context.Context) (<-chan *routerrpc.HtlcEvent, <-chan error, error) } + +// htlcIdentifierStr is a deterministic method that blends the chanID and htlcID +// of an in-flight HTLC to create a string that uniquely identifies it. +func htlcIdentifierStr(chanID, htlcID uint64) string { + return fmt.Sprintf("%v:%v", chanID, htlcID) +}