Skip to content

Commit

Permalink
Support reloading preferred_ranges
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus committed Dec 15, 2023
1 parent 8be9792 commit 7205659
Show file tree
Hide file tree
Showing 11 changed files with 110 additions and 84 deletions.
2 changes: 1 addition & 1 deletion connection_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) {
}

if n.punchy.GetTargetEverything() {
hostinfo.remotes.ForEach(n.hostMap.preferredRanges, func(addr *udp.Addr, preferred bool) {
hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr *udp.Addr, preferred bool) {
n.metricsTxPunchy.Inc(1)
n.intf.outside.WriteTo([]byte{1}, addr)
})
Expand Down
11 changes: 8 additions & 3 deletions connection_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ func Test_NewConnectionManagerTest(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}

// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)

cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Expand Down Expand Up @@ -122,7 +124,9 @@ func Test_NewConnectionManagerTest2(t *testing.T) {
preferredRanges := []*net.IPNet{localrange}

// Very incomplete mock objects
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)

cs := &CertState{
RawCertificate: []byte{},
PrivateKey: []byte{},
Expand Down Expand Up @@ -209,7 +213,8 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
_, vpncidr, _ := net.ParseCIDR("172.1.1.1/24")
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
preferredRanges := []*net.IPNet{localrange}
hostMap := NewHostMap(l, vpncidr, preferredRanges)
hostMap := newHostMap(l, vpncidr)
hostMap.preferredRanges.Store(&preferredRanges)

// Generate keys for CA and peer's cert.
pubCA, privCA, _ := ed25519.GenerateKey(rand.Reader)
Expand Down
4 changes: 2 additions & 2 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp iputil.VpnIp, pending bool) *ControlH
return nil
}

ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
ch := copyHostInfo(h, c.f.hostMap.GetPreferredRanges())
return &ch
}

Expand All @@ -157,7 +157,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp iputil.VpnIp, addr udp.Addr) *Control
}

hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
ch := copyHostInfo(hostInfo, c.f.hostMap.GetPreferredRanges())
return &ch
}

Expand Down
4 changes: 3 additions & 1 deletion control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
l := test.NewLogger()
// Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object
// To properly ensure we are not exposing core memory to the caller
hm := NewHostMap(l, &net.IPNet{}, make([]*net.IPNet, 0))
hm := newHostMap(l, &net.IPNet{})
hm.preferredRanges.Store(&[]*net.IPNet{})

remote1 := udp.NewAddr(net.ParseIP("0.0.0.100"), 4444)
remote2 := udp.NewAddr(net.ParseIP("1:2:3:4:5:6:7:8"), 4444)
ipNet := net.IPNet{
Expand Down
2 changes: 1 addition & 1 deletion handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *Handsha
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)

f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())).
Info("Blocked addresses for handshakes")

// Swap the packet store to benefit the original intended recipient
Expand Down
10 changes: 5 additions & 5 deletions handshake_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo := hh.hostinfo
// If we are out of time, clean up
if hh.counter >= hm.config.retries {
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)).
hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
WithField("initiatorIndex", hh.hostinfo.localIndexId).
WithField("remoteIndex", hh.hostinfo.remoteIndexId).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Expand Down Expand Up @@ -211,7 +211,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger
hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
}

remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.preferredRanges)
remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)

// We only care about a lighthouse trigger if we have new remotes to send to.
Expand All @@ -235,7 +235,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTrigger

// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
var sentTo []*udp.Addr
hostinfo.remotes.ForEach(hm.mainHostMap.preferredRanges, func(addr *udp.Addr, _ bool) {
hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
if err != nil {
Expand Down Expand Up @@ -362,7 +362,7 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*Han
hm.mainHostMap.RUnlock()
// Do not attempt promotion if you are a lighthouse
if !hm.lightHouse.amLighthouse {
h.TryPromoteBest(hm.mainHostMap.preferredRanges, hm.f)
h.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
}
return h, true
}
Expand Down Expand Up @@ -600,7 +600,7 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
}

func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
return c.mainHostMap.preferredRanges
return c.mainHostMap.GetPreferredRanges()
}

func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
Expand Down
4 changes: 3 additions & 1 deletion handshake_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) {
_, localrange, _ := net.ParseCIDR("10.1.1.1/24")
ip := iputil.Ip2VpnIp(net.ParseIP("172.1.1.2"))
preferredRanges := []*net.IPNet{localrange}
mainHM := NewHostMap(l, vpncidr, preferredRanges)
mainHM := newHostMap(l, vpncidr)
mainHM.preferredRanges.Store(&preferredRanges)

lh := newTestLighthouse()

cs := &CertState{
Expand Down
71 changes: 52 additions & 19 deletions hostmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/cidr"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
Expand Down Expand Up @@ -57,9 +58,8 @@ type HostMap struct {
Relays map[uint32]*HostInfo // Maps a Relay IDX to a Relay HostInfo object
RemoteIndexes map[uint32]*HostInfo
Hosts map[iputil.VpnIp]*HostInfo
preferredRanges []*net.IPNet
preferredRanges atomic.Pointer[[]*net.IPNet]
vpnCIDR *net.IPNet
metricsEnabled bool
l *logrus.Logger
}

Expand Down Expand Up @@ -254,21 +254,53 @@ type cachedPacketMetrics struct {
dropped metrics.Counter
}

func NewHostMap(l *logrus.Logger, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[iputil.VpnIp]*HostInfo{}
i := map[uint32]*HostInfo{}
r := map[uint32]*HostInfo{}
relays := map[uint32]*HostInfo{}
m := HostMap{
Indexes: i,
Relays: relays,
RemoteIndexes: r,
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
l: l,
func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR *net.IPNet, c *config.C) *HostMap {
hm := newHostMap(l, vpnCIDR)

hm.reload(c, true)
c.RegisterReloadCallback(func(c *config.C) {
hm.reload(c, false)
})

l.WithField("network", hm.vpnCIDR.String()).
WithField("preferredRanges", hm.GetPreferredRanges()).
Info("Main HostMap created")

return hm
}

func newHostMap(l *logrus.Logger, vpnCIDR *net.IPNet) *HostMap {
return &HostMap{
Indexes: map[uint32]*HostInfo{},
Relays: map[uint32]*HostInfo{},
RemoteIndexes: map[uint32]*HostInfo{},
Hosts: map[iputil.VpnIp]*HostInfo{},
vpnCIDR: vpnCIDR,
l: l,
}
}

func (hm *HostMap) reload(c *config.C, initial bool) {
if initial || c.HasChanged("preferred_ranges") {
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})

for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)

if err != nil {
hm.l.WithError(err).WithField("range", rawPreferredRanges).Warn("Failed to parse preferred ranges, ignoring")
continue
}

preferredRanges = append(preferredRanges, preferredRange)
}

oldRanges := hm.preferredRanges.Swap(&preferredRanges)
if !initial {
hm.l.WithField("oldPreferredRanges", *oldRanges).WithField("newPreferredRanges", preferredRanges).Info("preferred_ranges changed")
}
}
return &m
}

// EmitStats reports host, index, and relay counts to the stats collection system
Expand Down Expand Up @@ -457,7 +489,7 @@ func (hm *HostMap) queryVpnIp(vpnIp iputil.VpnIp, promoteIfce *Interface) *HostI
hm.RUnlock()
// Do not attempt promotion if you are a lighthouse
if promoteIfce != nil && !promoteIfce.lightHouse.amLighthouse {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
h.TryPromoteBest(hm.GetPreferredRanges(), promoteIfce)
}
return h

Expand Down Expand Up @@ -504,7 +536,8 @@ func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) {
}

func (hm *HostMap) GetPreferredRanges() []*net.IPNet {
return hm.preferredRanges
//NOTE: if preferredRanges is ever not stored before a load this will fail to dereference a nil pointer
return *hm.preferredRanges.Load()
}

func (hm *HostMap) ForEachVpnIp(f controlEach) {
Expand Down Expand Up @@ -596,7 +629,7 @@ func (i *HostInfo) SetRemoteIfPreferred(hm *HostMap, newRemote *udp.Addr) bool {
// NOTE: We do this loop here instead of calling `isPreferred` in
// remote_list.go so that we only have to loop over preferredRanges once.
newIsPreferred := false
for _, l := range hm.preferredRanges {
for _, l := range hm.GetPreferredRanges() {
// return early if we are already on a preferred remote
if l.Contains(currentRemote.IP) {
return false
Expand Down
37 changes: 33 additions & 4 deletions hostmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ import (
"net"
"testing"

"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
)

func TestHostMap_MakePrimary(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
hm := newHostMap(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)

f := &Interface{}
Expand Down Expand Up @@ -91,13 +91,12 @@ func TestHostMap_MakePrimary(t *testing.T) {

func TestHostMap_DeleteHostInfo(t *testing.T) {
l := test.NewLogger()
hm := NewHostMap(
hm := newHostMap(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
[]*net.IPNet{},
)

f := &Interface{}
Expand Down Expand Up @@ -205,3 +204,33 @@ func TestHostMap_DeleteHostInfo(t *testing.T) {
prim = hm.QueryVpnIp(1)
assert.Nil(t, prim)
}

func TestHostMap_reload(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)

hm := NewHostMapFromConfig(
l,
&net.IPNet{
IP: net.IP{10, 0, 0, 1},
Mask: net.IPMask{255, 255, 255, 0},
},
c,
)

toS := func(ipn []*net.IPNet) []string {
var s []string
for _, n := range ipn {
s = append(s, n.String())
}
return s
}

assert.Empty(t, hm.GetPreferredRanges())

c.ReloadConfigString("preferred_ranges: [1.1.1.0/24, 10.1.1.0/24]")
assert.EqualValues(t, []string{"1.1.1.0/24", "10.1.1.0/24"}, toS(hm.GetPreferredRanges()))

c.ReloadConfigString("preferred_ranges: [1.1.1.1/32]")
assert.EqualValues(t, []string{"1.1.1.1/32"}, toS(hm.GetPreferredRanges()))
}
47 changes: 1 addition & 46 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,52 +173,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
}
}

// Set up my internal host map
var preferredRanges []*net.IPNet
rawPreferredRanges := c.GetStringSlice("preferred_ranges", []string{})
// First, check if 'preferred_ranges' is set and fallback to 'local_range'
if len(rawPreferredRanges) > 0 {
for _, rawPreferredRange := range rawPreferredRanges {
_, preferredRange, err := net.ParseCIDR(rawPreferredRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse preferred ranges", err)
}
preferredRanges = append(preferredRanges, preferredRange)
}
}

// local_range was superseded by preferred_ranges. If it is still present,
// merge the local_range setting into preferred_ranges. We will probably
// deprecate local_range and remove in the future.
rawLocalRange := c.GetString("local_range", "")
if rawLocalRange != "" {
_, localRange, err := net.ParseCIDR(rawLocalRange)
if err != nil {
return nil, util.ContextualizeIfNeeded("Failed to parse local_range", err)
}

// Check if the entry for local_range was already specified in
// preferred_ranges. Don't put it into the slice twice if so.
var found bool
for _, r := range preferredRanges {
if r.String() == localRange.String() {
found = true
break
}
}
if !found {
preferredRanges = append(preferredRanges, localRange)
}
}

hostMap := NewHostMap(l, tunCidr, preferredRanges)
hostMap.metricsEnabled = c.GetBool("stats.message_metrics", false)

l.
WithField("network", hostMap.vpnCIDR.String()).
WithField("preferredRanges", hostMap.preferredRanges).
Info("Main HostMap created")

hostMap := NewHostMapFromConfig(l, tunCidr, c)
punchy := NewPunchyFromConfig(l, c)
lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr
enc.SetIndent("", " ")
}

return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.preferredRanges))
return enc.Encode(copyHostInfo(hostInfo, ifce.hostMap.GetPreferredRanges()))
}

func sshReload(c *config.C, w sshd.StringWriter) error {
Expand Down

0 comments on commit 7205659

Please sign in to comment.