From 072edd56b3fd3cc71f19518b720beaff860cddb9 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 19 Dec 2023 11:58:31 -0600 Subject: [PATCH] Fix re-entrant `GetOrHandshake` issues (#1044) --- connection_manager.go | 13 +++++-- connection_manager_test.go | 5 ++- examples/config.yml | 4 ++ handshake_manager.go | 7 ++-- hostmap.go | 2 +- inside.go | 2 +- lighthouse.go | 75 ++++++++++++++++++++++++++------------ ssh.go | 2 +- 8 files changed, 74 insertions(+), 36 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index a1897566a..f5dd5942e 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -23,6 +23,7 @@ const ( swapPrimary trafficDecision = 3 migrateRelays trafficDecision = 4 tryRehandshake trafficDecision = 5 + sendTestPacket trafficDecision = 6 ) type connectionManager struct { @@ -176,7 +177,7 @@ func (n *connectionManager) Run(ctx context.Context) { } func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, now time.Time) { - decision, hostinfo, primary := n.makeTrafficDecision(localIndex, p, nb, out, now) + decision, hostinfo, primary := n.makeTrafficDecision(localIndex, now) switch decision { case deleteTunnel: @@ -197,6 +198,9 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case tryRehandshake: n.tryRehandshake(hostinfo) + + case sendTestPacket: + n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) } n.resetRelayTrafficCheck(hostinfo) @@ -289,7 +293,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) } } -func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out []byte, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { +func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time) (trafficDecision, *HostInfo, *HostInfo) { n.hostMap.RLock() defer n.hostMap.RUnlock() @@ -356,6 +360,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out [] return deleteTunnel, hostinfo, nil } + decision := doNothing if hostinfo != nil && hostinfo.ConnectionState != nil && mainHostInfo { if !outTraffic { // If we aren't sending or receiving traffic then its an unused tunnel and we don't to test the tunnel. @@ -380,7 +385,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out [] } // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - n.intf.SendMessageToHostInfo(header.Test, header.TestRequest, hostinfo, p, nb, out) + decision = sendTestPacket } else { if n.l.Level >= logrus.DebugLevel { @@ -390,7 +395,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, p, nb, out [] n.pendingDeletion[hostinfo.localIndexId] = struct{}{} n.trafficTimer.Add(hostinfo.localIndexId, n.pendingDeletionInterval) - return doNothing, nil, nil + return decision, hostinfo, nil } func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { diff --git a/connection_manager_test.go b/connection_manager_test.go index 5bc3f6f5c..a2607a2b7 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -21,8 +21,9 @@ var vpnIp iputil.VpnIp func newTestLighthouse() *LightHouse { lh := &LightHouse{ - l: test.NewLogger(), - addrMap: map[iputil.VpnIp]*RemoteList{}, + l: test.NewLogger(), + addrMap: map[iputil.VpnIp]*RemoteList{}, + queryChan: make(chan iputil.VpnIp, 10), } lighthouses := map[iputil.VpnIp]struct{}{} staticList := map[iputil.VpnIp]struct{}{} diff --git a/examples/config.yml b/examples/config.yml index c0ac0f6b2..c0969e115 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -289,6 +289,10 @@ logging: # A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out #try_interval: 100ms #retries: 20 + + # query_buffer is the size of the buffer channel for querying lighthouses + #query_buffer: 64 + # trigger_buffer is the size of the buffer channel for quickly sending handshakes # after receiving the response for lighthouse queries #trigger_buffer: 64 diff --git a/handshake_manager.go b/handshake_manager.go index 00321d67a..b568cc88d 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -230,7 +230,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger // If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse // Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about // the learned public ip for them. Query again to short circuit the promotion counter - hm.lightHouse.QueryServer(vpnIp, hm.f) + hm.lightHouse.QueryServer(vpnIp) } // Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply @@ -374,13 +374,13 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() + defer hm.Unlock() if hh, ok := hm.vpnIps[vpnIp]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) } - hm.Unlock() return hh.hostinfo } @@ -421,8 +421,7 @@ func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han } } - hm.Unlock() - hm.lightHouse.QueryServer(vpnIp, hm.f) + hm.lightHouse.QueryServer(vpnIp) return hostinfo } diff --git a/hostmap.go b/hostmap.go index df388cd1a..a5adeb996 100644 --- a/hostmap.go +++ b/hostmap.go @@ -561,7 +561,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnIp, ifce) + ifce.lightHouse.QueryServer(i.vpnIp) } } diff --git a/inside.go b/inside.go index ff9e80b6e..62309628a 100644 --- a/inside.go +++ b/inside.go @@ -288,7 +288,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.vpnIp, f) + f.lightHouse.QueryServer(hostinfo.vpnIp) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") diff --git a/lighthouse.go b/lighthouse.go index 2193ad3ce..aa54c4bc5 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -74,6 +74,8 @@ type LightHouse struct { // IP's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]iputil.VpnIp] + queryChan chan iputil.VpnIp + calculatedRemotes atomic.Pointer[cidr.Tree4[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote metrics *MessageMetrics @@ -110,6 +112,7 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, nebulaPort: nebulaPort, punchConn: pc, punchy: p, + queryChan: make(chan iputil.VpnIp, c.GetUint32("handshakes.query_buffer", 64)), l: l, } lighthouses := make(map[iputil.VpnIp]struct{}) @@ -139,6 +142,8 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, } }) + h.startQueryWorker() + return &h, nil } @@ -443,9 +448,9 @@ func (lh *LightHouse) loadStaticMap(c *config.C, tunCidr *net.IPNet, staticList return nil } -func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList { +func (lh *LightHouse) Query(ip iputil.VpnIp) *RemoteList { if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip, f) + lh.QueryServer(ip) } lh.RLock() if v, ok := lh.addrMap[ip]; ok { @@ -456,30 +461,14 @@ func (lh *LightHouse) Query(ip iputil.VpnIp, f EncWriter) *RemoteList { return nil } -// This is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip iputil.VpnIp, f EncWriter) { - if lh.amLighthouse { - return - } - - if lh.IsLighthouseIP(ip) { - return - } - - // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") +// QueryServer is asynchronous so no reply should be expected +func (lh *LightHouse) QueryServer(ip iputil.VpnIp) { + // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses + if lh.amLighthouse || lh.IsLighthouseIP(ip) { return } - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) - nb := make([]byte, 12, 12) - out := make([]byte, mtu) - for n := range lighthouses { - f.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) - } + lh.queryChan <- ip } func (lh *LightHouse) QueryCache(ip iputil.VpnIp) *RemoteList { @@ -752,6 +741,46 @@ func NewUDPAddrFromLH6(ipp *Ip6AndPort) *udp.Addr { return udp.NewAddr(lhIp6ToIp(ipp), uint16(ipp.Port)) } +func (lh *LightHouse) startQueryWorker() { + if lh.amLighthouse { + return + } + + go func() { + nb := make([]byte, 12, 12) + out := make([]byte, mtu) + + for { + select { + case <-lh.ctx.Done(): + return + case ip := <-lh.queryChan: + lh.innerQueryServer(ip, nb, out) + } + } + }() +} + +func (lh *LightHouse) innerQueryServer(ip iputil.VpnIp, nb, out []byte) { + if lh.IsLighthouseIP(ip) { + return + } + + // Send a query to the lighthouses and hope for the best next time + query, err := NewLhQueryByInt(ip).Marshal() + if err != nil { + lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") + return + } + + lighthouses := lh.GetLighthouses() + lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) + + for n := range lighthouses { + lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + } +} + func (lh *LightHouse) StartUpdateWorker() { interval := lh.GetUpdateInterval() if lh.amLighthouse || interval == 0 { diff --git a/ssh.go b/ssh.go index c410e0101..8e48fc48f 100644 --- a/ssh.go +++ b/ssh.go @@ -518,7 +518,7 @@ func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.Stri } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp, ifce) + rl := ifce.lightHouse.Query(vpnIp) if rl != nil { cm = rl.CopyCache() }