From ac8ce8270171da3f0d73c06a437fd78ea9461dc1 Mon Sep 17 00:00:00 2001 From: yyforyongyu Date: Mon, 1 Jul 2024 18:14:10 +0800 Subject: [PATCH] contractcourt: fix race access to `c.activeResolvers` --- contractcourt/channel_arbitrator.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 1d2268afe03..86418a6414a 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -811,11 +811,8 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // Report returns htlc reports for the active resolvers. func (c *ChannelArbitrator) Report() []*ContractReport { - c.activeResolversLock.RLock() - defer c.activeResolversLock.RUnlock() - var reports []*ContractReport - for _, resolver := range c.activeResolvers { + for _, resolver := range c.resolvers() { r, ok := resolver.(reportingContractResolver) if !ok { continue @@ -1569,6 +1566,7 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, // resolveContracts updates the activeResolvers list and starts to resolve each // contract concurrently, and launches them. func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) { + // Update the active contract resolvers. c.activeResolversLock.Lock() c.activeResolvers = resolvers c.activeResolversLock.Unlock() @@ -1576,7 +1574,7 @@ func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) { // Launch all resolvers. c.launchResolvers() - for _, contract := range resolvers { + for _, contract := range c.resolvers() { c.wg.Add(1) go c.resolveContract(contract) } @@ -1584,11 +1582,7 @@ func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) { // launchResolvers launches all the active resolvers. func (c *ChannelArbitrator) launchResolvers() { - c.activeResolversLock.Lock() - resolvers := c.activeResolvers - c.activeResolversLock.Unlock() - - for _, contract := range resolvers { + for _, contract := range c.resolvers() { // If the contract is already resolved, there's no need to // launch it again. if contract.IsResolved() { @@ -3426,3 +3420,14 @@ func (c *ChannelArbitrator) abandonForwards(htlcs fn.Set[uint64]) error { return nil } + +// resolvers returns a copy of the active resolvers. +func (c *ChannelArbitrator) resolvers() []ContractResolver { + c.activeResolversLock.Lock() + defer c.activeResolversLock.Unlock() + + resolvers := make([]ContractResolver, 0, len(c.activeResolvers)) + resolvers = append(resolvers, c.activeResolvers...) + + return resolvers +}