Skip to content

Commit

Permalink
Fix hostmap deletion and lighthouse version choice
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Oct 11, 2024
1 parent b3f2d49 commit c00422f
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 37 deletions.
4 changes: 4 additions & 0 deletions handshake_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,7 @@ func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {}
func (mw *mockEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return nil
}

func (mw *mockEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: cert.Version2}
}
16 changes: 8 additions & 8 deletions hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,31 +352,31 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) {
h := hm.Hosts[addr]
for h != nil {
if h == hostinfo {
hm.unlockedInnerDeleteHostInfo(h)
hm.unlockedInnerDeleteHostInfo(h, addr)
}
h = h.next
}
}
}

func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo) {
primary, ok := hm.Hosts[hostinfo.vpnAddrs[0]]
func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) {
primary, ok := hm.Hosts[addr]
if ok && primary == hostinfo {
// The vpnIp pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, hostinfo.vpnAddrs[0])
// The vpn addr pointer points to the same hostinfo as the local index id, we can remove it
delete(hm.Hosts, addr)
if len(hm.Hosts) == 0 {
hm.Hosts = map[netip.Addr]*HostInfo{}
}

if hostinfo.next != nil {
// We had more than 1 hostinfo at this vpnip, promote the next in the list to primary
hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo.next
// We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary
hm.Hosts[addr] = hostinfo.next
// It is primary, there is no previous hostinfo now
hostinfo.next.prev = nil
}

} else {
// Relink if we were in the middle of multiple hostinfos for this vpn ip
// Relink if we were in the middle of multiple hostinfos for this vpn addr
if hostinfo.prev != nil {
hostinfo.prev.next = hostinfo.next
}
Expand Down
5 changes: 5 additions & 0 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ type EncWriter interface {
SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte)
Handshake(vpnIp netip.Addr)
GetHostInfo(vpnIp netip.Addr) *HostInfo
GetCertState() *CertState
}

type sendRecvErrorConfig uint8
Expand Down Expand Up @@ -428,6 +429,10 @@ func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return f.hostMap.QueryVpnAddr(vpnIp)
}

func (f *Interface) GetCertState() *CertState {
return f.pki.getCertState()
}

func (f *Interface) Close() error {
f.closed.Store(true)

Expand Down
35 changes: 9 additions & 26 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ type LightHouse struct {
staticList atomic.Pointer[map[netip.Addr]struct{}]
lighthouses atomic.Pointer[map[netip.Addr]struct{}]

interval atomic.Int64
updateCancel context.CancelFunc
ifce EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16
protocolVersion atomic.Uint32 // The default protocol version to use if we can't determine which to use from the tunnel
interval atomic.Int64
updateCancel context.CancelFunc
ifce EncWriter
nebulaPort uint32 // 32 bits because protobuf does not have a uint16

advertiseAddrs atomic.Pointer[[]netip.AddrPort]

Expand Down Expand Up @@ -352,16 +351,6 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error {
}
}

v := c.GetUint32("pki.default_version", 1)
switch v {
case 1:
lh.protocolVersion.Store(1)
case 2:
lh.protocolVersion.Store(2)
default:
return fmt.Errorf("invalid version for lighthouse: %v", v)
}

return nil
}

Expand Down Expand Up @@ -750,15 +739,12 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
}

// Send a query to the lighthouses and hope for the best next time
//TODO: this is not sufficient since the version depends on the certs loaded into memory as well
v := lh.protocolVersion.Load()
v := lh.ifce.GetCertState().defaultVersion
msg := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{},
}

//TODO: remove this
v = 2
if v == 1 {
if !addr.Is4() {
lh.l.WithField("vpnAddr", addr).Error("Can't query lighthouse for v6 address using a v1 protocol")
Expand Down Expand Up @@ -843,7 +829,7 @@ func (lh *LightHouse) SendUpdate() {
}
}

v := lh.protocolVersion.Load()
v := lh.ifce.GetCertState().defaultVersion
msg := &NebulaMeta{
Type: NebulaMeta_HostUpdateNotification,
Details: &NebulaMetaDetails{
Expand All @@ -852,8 +838,6 @@ func (lh *LightHouse) SendUpdate() {
},
}

//TODO: remove this
v = 2
if v == 1 {
var relays []uint32
for _, r := range lh.GetRelaysForMe() {
Expand Down Expand Up @@ -1042,11 +1026,10 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []neti
found, ln, err = lhh.lh.queryAndPrepMessage(fromVpnAddrs[0], func(c *cache) (int, error) {
n = lhh.resetMeta()
n.Type = NebulaMeta_HostPunchNotification
//TODO: unsure which version to use. If we had access to the hostmap we could see if there is already a tunnel
// and use that version then fallback to our default configuration
targetHI := lhh.lh.ifce.GetHostInfo(queryVpnIp)
useVersion = cert.Version(lhh.lh.protocolVersion.Load())
if targetHI != nil {
if targetHI == nil {
useVersion = lhh.lh.ifce.GetCertState().defaultVersion
} else {
useVersion = targetHI.GetCert().Certificate.Version()
}

Expand Down
10 changes: 8 additions & 2 deletions lighthouse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"testing"

"github.com/gaissmai/bart"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test"
Expand Down Expand Up @@ -427,8 +428,9 @@ type testLhReply struct {
}

type testEncWriter struct {
lastReply testLhReply
metaFilter *NebulaMeta_MessageType
lastReply testLhReply
metaFilter *NebulaMeta_MessageType
protocolVersion cert.Version
}

func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) {
Expand Down Expand Up @@ -474,6 +476,10 @@ func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo {
return nil
}

func (tw *testEncWriter) GetCertState() *CertState {
return &CertState{defaultVersion: tw.protocolVersion}
}

// assertIp4InArray asserts every address in want is at the same position in have and that the lengths match
func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) {
if !assert.Len(t, have, len(want)) {
Expand Down
16 changes: 15 additions & 1 deletion pki.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,24 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) {
}
}

rawDefaultVersion := c.GetUint32("pki.default_version", 1)
if v1 == nil && v2 == nil {
return nil, errors.New("no certificates found in pki.cert")
}

useDefaultVersion := uint32(1)
if v1 == nil {
// The only condition that requires v2 as the default is if only a v2 certificate is present
// We do this to avoid having to configure it specifically in the config file
useDefaultVersion = 2
}

rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion)
var defaultVersion cert.Version
switch rawDefaultVersion {
case 1:
if v1 == nil {
return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert")
}
defaultVersion = cert.Version1
case 2:
defaultVersion = cert.Version2
Expand Down

0 comments on commit c00422f

Please sign in to comment.