Skip to content

Commit

Permalink
Clean up a hostinfo to reduce memory usage (#955)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus authored Nov 2, 2023
1 parent 2769783 commit a44e1b8
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 266 deletions.
3 changes: 0 additions & 3 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type ConnectionState struct {
messageCounter atomic.Uint64
window *Bits
writeLock sync.Mutex
ready bool
}

func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
Expand Down Expand Up @@ -71,7 +70,6 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i
H: hs,
initiator: initiator,
window: b,
ready: false,
myCert: certState.Certificate,
}

Expand All @@ -83,6 +81,5 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"certificate": cs.peerCert,
"initiator": cs.initiator,
"message_counter": cs.messageCounter.Load(),
"ready": cs.ready,
})
}
2 changes: 0 additions & 2 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ type ControlHostInfo struct {
LocalIndex uint32 `json:"localIndex"`
RemoteIndex uint32 `json:"remoteIndex"`
RemoteAddrs []*udp.Addr `json:"remoteAddrs"`
CachedPackets int `json:"cachedPackets"`
Cert *cert.NebulaCertificate `json:"cert"`
MessageCounter uint64 `json:"messageCounter"`
CurrentRemote *udp.Addr `json:"currentRemote"`
Expand Down Expand Up @@ -234,7 +233,6 @@ func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CachedPackets: len(h.packetStore),
CurrentRelaysToMe: h.relayState.CopyRelayIps(),
CurrentRelaysThroughMe: h.relayState.CopyRelayForIps(),
}
Expand Down
3 changes: 1 addition & 2 deletions control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []*udp.Addr{remote2, remote1},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
CurrentRemote: udp.NewAddr(net.ParseIP("0.0.0.100"), 4444),
Expand All @@ -105,7 +104,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) {
}

// Make sure we don't have any unexpected fields
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "CachedPackets", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi)
test.AssertDeepCopyEqual(t, &expectedInfo, thi)

// Make sure we don't panic if the host info doesn't have a cert yet
Expand Down
31 changes: 0 additions & 31 deletions handshake.go

This file was deleted.

96 changes: 43 additions & 53 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"time"

"github.com/flynn/noise"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
Expand All @@ -13,20 +14,20 @@ import (

// This function constructs a handshake packet, but does not actually send it
// Sending is done by the handshake manager
func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
err := f.handshakeManager.allocateIndex(hostinfo)
func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
err := f.handshakeManager.allocateIndex(hh)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index")
return false
}

certState := f.pki.GetCertState()
ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0)
hostinfo.ConnectionState = ci
hh.hostinfo.ConnectionState = ci

hsProto := &NebulaHandshakeDetails{
InitiatorIndex: hostinfo.localIndexId,
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: certState.RawCertificateNoKey,
}
Expand All @@ -39,7 +40,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
hsBytes, err = hs.Marshal()

if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false
}
Expand All @@ -49,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {

msg, _, _, err := ci.H.WriteMessage(h, hsBytes)
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).
f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage")
return false
}
Expand All @@ -58,9 +59,8 @@ func ixHandshakeStage0(f *Interface, hostinfo *HostInfo) bool {
// handshake packet 1 from the responder
ci.window.Update(f.l, 1)

hostinfo.HandshakePacket[0] = msg
hostinfo.HandshakeReady = true
hostinfo.handshakeStart = time.Now()
hh.hostinfo.HandshakePacket[0] = msg
hh.ready = true
return true
}

Expand Down Expand Up @@ -140,9 +140,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
},
}

hostinfo.Lock()
defer hostinfo.Unlock()

f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
Expand Down Expand Up @@ -208,19 +205,12 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
if err != nil {
switch err {
case ErrAlreadySeen:
// Update remote if preferred (Note we have to switch to locking
// the existing hostinfo, and then switch back so the defer Unlock
// higher in this function still works)
hostinfo.Unlock()
existing.Lock()
// Update remote if preferred
if existing.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
existing.Unlock()
hostinfo.Lock()

msg = existing.HandshakePacket[2]
f.messageMetrics.Tx(header.Handshake, header.MessageSubType(msg[1]), 1)
Expand Down Expand Up @@ -307,7 +297,6 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("sentCachedPackets", len(hostinfo.packetStore)).
Info("Handshake message sent")
}
} else {
Expand All @@ -323,25 +312,26 @@ func ixHandshakeStage1(f *Interface, addr *udp.Addr, via *ViaSender, packet []by
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("sentCachedPackets", len(hostinfo.packetStore)).
Info("Handshake message sent")
}

f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)
hostinfo.ConnectionState.messageCounter.Store(2)
hostinfo.remotes.ResetBlockedRemotes()

return
}

func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *HostInfo, packet []byte, h *header.H) bool {
if hostinfo == nil {
func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hh *HandshakeHostInfo, packet []byte, h *header.H) bool {
if hh == nil {
// Nothing here to tear down, got a bogus stage 2 packet
return true
}

hostinfo.Lock()
defer hostinfo.Unlock()
hh.Lock()
defer hh.Unlock()

hostinfo := hh.hostinfo
if addr != nil {
if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
Expand All @@ -350,22 +340,6 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
}

ci := hostinfo.ConnectionState
if ci.ready {
f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h).
Info("Handshake is already complete")

// Update remote if preferred
if hostinfo.SetRemoteIfPreferred(f.hostMap, addr) {
// Send a test packet to ensure the other side has also switched to
// the preferred remote
f.SendMessageToVpnIp(header.Test, header.TestRequest, hostinfo.vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}

// We already have a complete tunnel, there is nothing that can be done by processing further stage 1 packets
return false
}

msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).
Expand Down Expand Up @@ -422,22 +396,22 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
f.handshakeManager.DeleteHostInfo(hostinfo)

// Create a new hostinfo/handshake for the intended vpn ip
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHostInfo *HostInfo) {
f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) {
//TODO: this doesnt know if its being added or is being used for caching a packet
// Block the current used address
newHostInfo.remotes = hostinfo.remotes
newHostInfo.remotes.BlockRemote(addr)
newHH.hostinfo.remotes = hostinfo.remotes
newHH.hostinfo.remotes.BlockRemote(addr)

// Get the correct remote list for the host we did handshake with
hostinfo.remotes = f.lightHouse.QueryCache(vpnIp)

f.l.WithField("blockedUdpAddrs", newHostInfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHostInfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()).WithField("vpnIp", vpnIp).
WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.preferredRanges)).
Info("Blocked addresses for handshakes")

// Swap the packet store to benefit the original intended recipient
newHostInfo.packetStore = hostinfo.packetStore
hostinfo.packetStore = []*cachedPacket{}
newHH.packetStore = hh.packetStore
hh.packetStore = []*cachedPacket{}

// Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down
hostinfo.vpnIp = vpnIp
Expand All @@ -450,15 +424,15 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
// Mark packet 2 as seen so it doesn't show up as missed
ci.window.Update(f.l, 2)

duration := time.Since(hostinfo.handshakeStart).Nanoseconds()
duration := time.Since(hh.startTime).Nanoseconds()
f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).
WithField("durationNs", duration).
WithField("sentCachedPackets", len(hostinfo.packetStore)).
WithField("sentCachedPackets", len(hh.packetStore)).
Info("Handshake message received")

hostinfo.remoteIndexId = hs.Details.ResponderIndex
Expand All @@ -482,7 +456,23 @@ func ixHandshakeStage2(f *Interface, addr *udp.Addr, via *ViaSender, hostinfo *H
// Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp
f.handshakeManager.Complete(hostinfo, f)
f.connectionManager.AddTrafficWatch(hostinfo.localIndexId)
hostinfo.handshakeComplete(f.l, f.cachedPacketMetrics)

hostinfo.ConnectionState.messageCounter.Store(2)

if f.l.Level >= logrus.DebugLevel {
hostinfo.logger(f.l).Debugf("Sending %d stored packets", len(hh.packetStore))
}

if len(hh.packetStore) > 0 {
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range hh.packetStore {
cp.callback(cp.messageType, cp.messageSubType, hostinfo, cp.packet, nb, out)
}
f.cachedPacketMetrics.sent.Inc(int64(len(hh.packetStore)))
}

hostinfo.remotes.ResetBlockedRemotes()
f.metricHandshakes.Update(duration)

return false
Expand Down
Loading

0 comments on commit a44e1b8

Please sign in to comment.