From 554f849f7545c29b701988ee23f0ba0ca7b3c088 Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Sat, 21 Sep 2024 20:25:18 -0500 Subject: [PATCH] Support multiple vpn addrs in lighthouse and hostmap --- connection_manager.go | 2 +- control_tester.go | 4 +- handshake_ix.go | 19 +++-- handshake_manager.go | 2 +- hostmap.go | 18 ++++- lighthouse.go | 174 +++++++++++++++++++++++------------------- lighthouse_test.go | 6 +- outside.go | 23 +++--- 8 files changed, 140 insertions(+), 108 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index f331abacf..9c38b5d87 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case deleteTunnel: if n.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnAddr(hostinfo.vpnAddrs[0]) + n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: diff --git a/control_tester.go b/control_tester.go index 586617af7..93c4e06a8 100644 --- a/control_tester.go +++ b/control_tester.go @@ -49,7 +49,7 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() @@ -65,7 +65,7 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) // This is necessary to inform an initiator of possible relays for communicating with a responder func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() diff --git a/handshake_ix.go b/handshake_ix.go index bd5e3ff9c..df853ab1a 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,6 +2,7 @@ package nebula import ( "net/netip" + "slices" "time" "github.com/flynn/noise" @@ -230,7 +231,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs[0]) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) hostinfo.CreateRemoteCIDR(remoteCert.Certificate) @@ -436,9 +437,13 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, n := range vpnNetworks { + vpnAddrs[i] = n.Addr() + } + // Ensure the right host responded - //TODO: this is a horribly broken test - if vpnNetworks[0].Addr() != hostinfo.vpnAddrs[0] { + if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). WithField("udpAddr", addr).WithField("certName", certName). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). @@ -455,7 +460,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha newHH.hostinfo.remotes.BlockRemote(addr) // Get the correct remote list for the host we did handshake with - hostinfo.remotes = f.lightHouse.QueryCache(vpnNetworks[0].Addr()) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnNetworks", vpnNetworks). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). @@ -466,10 +471,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hh.packetStore = []*cachedPacket{} // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnAddrs = nil - for _, n := range vpnNetworks { - hostinfo.vpnAddrs = append(hostinfo.vpnAddrs, n.Addr()) - } + hostinfo.vpnAddrs = vpnAddrs f.sendCloseTunnel(hostinfo) }) @@ -492,6 +494,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo.remoteIndexId = hs.Details.ResponderIndex hostinfo.lastHandshakeTime = hs.Details.Time + hostinfo.vpnAddrs = vpnAddrs // Store their cert and our symmetric keys ci.peerCert = remoteCert diff --git a/handshake_manager.go b/handshake_manager.go index 258d5ae94..6b3902dfa 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -209,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp}) } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) diff --git a/hostmap.go b/hostmap.go index fbafc06d5..63601ee37 100644 --- a/hostmap.go +++ b/hostmap.go @@ -308,7 +308,7 @@ func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { hm.Lock() // If we have a previous or next hostinfo then we are not the last one for this vpn ip final := (hostinfo.next == nil && hostinfo.prev == nil) - hm.unlockedDeleteHostInfo(hostinfo) + hm.unlockedDeleteHostInfo(hostinfo, false) hm.Unlock() return final @@ -345,7 +345,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { hostinfo.prev = nil } -func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { +func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo, dontRecurse bool) { primary, ok := hm.Hosts[hostinfo.vpnAddrs[0]] if ok && primary == hostinfo { // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it @@ -399,6 +399,18 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } + + if !dontRecurse { + for _, addr := range hostinfo.vpnAddrs { + h := hm.Hosts[addr] + for h != nil { + if h == hostinfo { + hm.unlockedDeleteHostInfo(h, true) + } + h = h.next + } + } + } } func (hm *HostMap) QueryIndex(index uint32) *HostInfo { @@ -501,7 +513,7 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { check := hostinfo for check != nil { if i > MaxHostInfosPerVpnIp { - hm.unlockedDeleteHostInfo(check) + hm.unlockedDeleteHostInfo(check, false) } check = check.next i++ diff --git a/lighthouse.go b/lighthouse.go index 23a4b9f50..5549e8386 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/netip" + "slices" "strconv" "sync" "sync/atomic" @@ -472,12 +473,12 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return nil } -func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { - if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip) +func (lh *LightHouse) Query(vpnAddr netip.Addr) *RemoteList { + if !lh.IsLighthouseIP(vpnAddr) { + lh.QueryServer(vpnAddr) } lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { lh.RUnlock() return v } @@ -486,18 +487,18 @@ func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip netip.Addr) { +func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) { // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses - if lh.amLighthouse || lh.IsLighthouseIP(ip) { + if lh.amLighthouse || lh.IsLighthouseIP(vpnAddr) { return } - lh.queryChan <- ip + lh.queryChan <- vpnAddr } -func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { +func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddrs[0]]; ok { lh.RUnlock() return v } @@ -506,16 +507,16 @@ func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(ip) + return lh.unlockedGetRemoteList(vpnAddrs) } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing // details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? - if v, ok := lh.addrMap[vpnIp]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { // Swap lh lock for remote list lock v.RLock() defer v.RUnlock() @@ -523,7 +524,7 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, lh.RUnlock() // vpnIp should also be the owner here since we are a lighthouse. - c := v.cache[vpnIp] + c := v.cache[vpnAddr] // Make sure we have if c != nil { n, err := f(c) @@ -535,20 +536,25 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, return false, 0, nil } -func (lh *LightHouse) DeleteVpnAddr(vpnIp netip.Addr) { +func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.GetStaticHostList()[vpnIp]; ok { + if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { return } lh.Lock() - //l.Debugln(lh.addrMap) - delete(lh.addrMap, vpnIp) - - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", vpnIp) + rm, ok := lh.addrMap[allVpnAddrs[0]] + if ok { + for _, addr := range allVpnAddrs { + srm := lh.addrMap[addr] + if srm == rm { + delete(lh.addrMap, addr) + if lh.l.Level >= logrus.DebugLevel { + lh.l.Debugf("deleting %s from lighthouse.", addr) + } + } + } } - lh.Unlock() } @@ -556,9 +562,9 @@ func (lh *LightHouse) DeleteVpnAddr(vpnIp netip.Addr) { // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnAddr netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() ctx := lh.ctx @@ -572,12 +578,12 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t am.shouldRebuild = true }) if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnAddr, "entry": i + 1}, err) } am.unlockedSetHostnamesResults(hr) for _, addrPort := range hr.GetIPs() { - if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { continue } switch { @@ -589,49 +595,52 @@ func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, t } // Mark it as static in the caller provided map - staticList[vpnIp] = struct{}{} + staticList[vpnAddr] = struct{}{} return nil } // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { //TODO: this needs to support v6 addresses too tree := lh.getCalculatedRemotes() if tree == nil { return false } - calculatedRemotes, ok := tree.Lookup(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnAddr) if !ok { return false } var calculated []*V4AddrPort for _, cr := range calculatedRemotes { - c := cr.Apply(vpnIp) + c := cr.Apply(vpnAddr) if c != nil { calculated = append(calculated, c) } } lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) + am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnAddr, calculated, lh.unlockedShouldAddV4) return len(calculated) > 0 } -// unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { - am, ok := lh.addrMap[vpnIp] +// unlockedGetRemoteList +// assumes you have the lh lock +func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { + am, ok := lh.addrMap[allAddrs[0]] if !ok { - am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) - lh.addrMap[vpnIp] = am + am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + for _, addr := range allAddrs { + lh.addrMap[addr] = am + } } return am } @@ -693,13 +702,25 @@ func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *V6AddrPort) bool return true } -func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { - if _, ok := lh.GetLighthouses()[vpnIp]; ok { +func (lh *LightHouse) IsLighthouseIP(vpnAddr netip.Addr) bool { + if _, ok := lh.GetLighthouses()[vpnAddr]; ok { return true } return false } +// TODO: IsLighthouseIP should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake +// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially +func (lh *LightHouse) IsAnyLighthouseIP(vpnAddr []netip.Addr) bool { + l := lh.GetLighthouses() + for _, a := range vpnAddr { + if _, ok := l[a]; ok { + return true + } + } + return false +} + func (lh *LightHouse) startQueryWorker() { if lh.amLighthouse { return @@ -915,20 +936,18 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { return lhh.meta } -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, reqHostinfo *HostInfo, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") - //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") - //TODO: send recv_error? return } @@ -936,24 +955,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, reqHostinfo *H switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, reqHostinfo, rAddr, w) + lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, reqHostinfo) + lhh.handleHostQueryReply(n, fromVpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, reqHostinfo, w) + lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, reqHostinfo, w) + lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostInfo, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -1001,15 +1020,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostIn } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(reqHostinfo.vpnAddrs[0], func(c *cache) (int, error) { + found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification //TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel @@ -1021,15 +1040,15 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostIn } if useVersion == cert.Version1 { - if !reqHostinfo.vpnAddrs[0].Is4() { + if !fromVpnAddrs[0].Is4() { return 0, fmt.Errorf("invalid vpn ip for v1 handleHostQuery") } - b := reqHostinfo.vpnAddrs[0].As4() + b := fromVpnAddrs[0].As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) lhh.coalesceAnswers(useVersion, c, n) } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0]) + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) lhh.coalesceAnswers(useVersion, c, n) } else { @@ -1044,7 +1063,7 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, reqHostinfo *HostIn } if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } @@ -1094,9 +1113,8 @@ func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *Nebul } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *HostInfo) { - //TODO: this is kind of dumb - if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs []netip.Addr) { + if !lhh.lh.IsAnyLighthouseIP(fromVpnAddrs) { return } @@ -1111,12 +1129,12 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *H certVpnIp = protoAddrToNetAddr(n.Details.VpnAddr) } - am := lhh.lh.unlockedGetRemoteList(certVpnIp) + am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnIp}) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(reqHostinfo.vpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(fromVpnAddrs[0], certVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], certVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) var relays []netip.Addr if len(n.Details.OldRelayVpnAddrs) > 0 { @@ -1133,7 +1151,7 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *H } } - am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], certVpnIp, relays) + am.unlockedSetRelay(fromVpnAddrs[0], certVpnIp, relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block @@ -1143,10 +1161,10 @@ func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, reqHostinfo *H } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", reqHostinfo.vpnAddrs) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) } return } @@ -1167,20 +1185,20 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHos //todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? //todo why do we care about the vpnip in the packet? We know where it came from, right? - if detailsVpnIp != reqHostinfo.vpnAddrs[0] { + if !slices.Contains(fromVpnAddrs, detailsVpnIp) { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") } return } lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(reqHostinfo.vpnAddrs[0]) + am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(reqHostinfo.vpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetV4(fromVpnAddrs[0], detailsVpnIp, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], detailsVpnIp, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) var relays []netip.Addr if len(n.Details.OldRelayVpnAddrs) > 0 { @@ -1197,22 +1215,22 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHos } } - am.unlockedSetRelay(reqHostinfo.vpnAddrs[0], detailsVpnIp, relays) + am.unlockedSetRelay(fromVpnAddrs[0], detailsVpnIp, relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck if useVersion == cert.Version1 { - if !reqHostinfo.vpnAddrs[0].Is4() { - lhh.l.WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + if !fromVpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") return } - vpnIpB := reqHostinfo.vpnAddrs[0].As4() + vpnIpB := fromVpnAddrs[0].As4() n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnIpB[:]) } else if useVersion == cert.Version2 { - n.Details.VpnAddr = netAddrToProtoAddr(reqHostinfo.vpnAddrs[0]) + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) } else { panic("unsupported version") @@ -1220,17 +1238,17 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, reqHos ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnAddrs", reqHostinfo.vpnAddrs).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, reqHostinfo.vpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnIp(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, reqHostinfo *HostInfo, w EncWriter) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { //TODO: this is kinda stupid - if !lhh.lh.IsLighthouseIP(reqHostinfo.vpnAddrs[0]) { + if !lhh.lh.IsAnyLighthouseIP(fromVpnAddrs) { return } diff --git a/lighthouse_test.go b/lighthouse_test.go index 0c315c09c..2cdfce79a 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -135,7 +135,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { mw := &mockEncWriter{} - hi := &HostInfo{vpnAddrs: []netip.Addr{vpnIp2}} + hi := []netip.Addr{vpnIp2} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ @@ -325,7 +325,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{myVpnIp}}, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) return w.lastReply } @@ -350,7 +350,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, &HostInfo{vpnAddrs: []netip.Addr{vpnIp}}, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) } //TODO: this is a RemoteList test diff --git a/outside.go b/outside.go index a94ac9c17..f504bb406 100644 --- a/outside.go +++ b/outside.go @@ -145,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - lhf.HandleRequest(ip, hostinfo, d, f) + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic @@ -230,9 +230,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { - // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage - //TODO: we should delete all related vpnaddrs too - f.lightHouse.DeleteVpnAddr(hostInfo.vpnAddrs[0]) + // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs) } } @@ -241,26 +240,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { - if ip.IsValid() && hostinfo.remote != ip { +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) { + if vpnAddr.IsValid() && hostinfo.remote != vpnAddr { //TODO: this is weird now that we can have multiple vpn addrs - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], ip.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(ip) + hostinfo.SetRemote(vpnAddr) } }