Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cert-v2: backwards compatibility trickery for ipv6 #1245

Merged
merged 4 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ type CachedCertificate struct {
signerFingerprint string
}

// String returns the string form of the wrapped certificate by delegating
// to the String method of the Certificate field.
func (cc *CachedCertificate) String() string {
	inner := cc.Certificate
	return inner.String()
}

// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate.
func UnmarshalCertificate(b []byte) (Certificate, error) {
//TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshalers?
Expand Down
13 changes: 6 additions & 7 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package nebula
import (
"crypto/rand"
"encoding/json"
"fmt"
"sync"
"sync/atomic"

Expand All @@ -26,8 +27,7 @@ type ConnectionState struct {
writeLock sync.Mutex
}

func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern noise.HandshakePattern) *ConnectionState {
crt := cs.GetDefaultCertificate()
func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) {
var dhFunc noise.DHFunc
switch crt.Curve() {
case cert.Curve_CURVE25519:
Expand All @@ -39,8 +39,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
dhFunc = noiseutil.DHP256
}
default:
l.Errorf("invalid curve: %s", crt.Curve())
return nil
return nil, fmt.Errorf("invalid curve: %s", crt.Curve())
}

var ncs noise.CipherSuite
Expand All @@ -53,7 +52,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()}

b := NewBits(ReplayWindow)
// Clear out bit 0, we never transmit it and we don't want it showing as packet loss
// Clear out bit 0, we never transmit it, and we don't want it showing as packet loss
b.Update(l, 0)

hs, err := noise.NewHandshakeState(noise.Config{
Expand All @@ -67,7 +66,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
PresharedKeyPlacement: 0,
})
if err != nil {
return nil
return nil, fmt.Errorf("NewConnectionState: %s", err)
}

// The queue and ready params prevent a counter race that would happen when
Expand All @@ -81,7 +80,7 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, initiator bool, pattern
// always start the counter from 2, as packet 1 and packet 2 are handshake packets.
ci.messageCounter.Add(2)

return ci
return ci, nil
}

func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
Expand Down
1 change: 1 addition & 0 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate {
_, found := c.f.myVpnAddrsTable.Lookup(vpnIp)
if found {
//TODO: we might have 2 certs....
//TODO: this should return our latest version cert
return c.f.pki.getDefaultCertificate().Copy()
}
hi := c.f.hostMap.QueryVpnAddr(vpnIp)
Expand Down
77 changes: 69 additions & 8 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,55 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {
return false
}

// If we're connecting to a v6 address we must use a v2 cert
cs := f.pki.getCertState()
ci := NewConnectionState(f.l, cs, true, noise.HandshakeIX)
v := cs.defaultVersion
for _, a := range hh.hostinfo.vpnAddrs {
if a.Is6() {
v = cert.Version2
break
}
}

crt := cs.getCertificate(v)
if crt == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate is available")
return false
}

crtHs := cs.getHandshakeBytes(v)
if crtHs == nil {
f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Unable to handshake with host because no certificate handshake bytes is available")
}

ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", v).
Error("Failed to create connection state")
return false
}
hh.hostinfo.ConnectionState = ci

hs := &NebulaHandshake{
Details: &NebulaHandshakeDetails{
InitiatorIndex: hh.hostinfo.localIndexId,
Time: uint64(time.Now().UnixNano()),
Cert: cs.getDefaultHandshakeBytes(),
CertVersion: uint32(cs.defaultVersion),
Cert: crtHs,
CertVersion: uint32(v),
},
}

hsBytes, err := hs.Marshal()
if err != nil {
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).
f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message")
return false
}
Expand All @@ -63,22 +96,39 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool {

func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) {
cs := f.pki.getCertState()
ci := NewConnectionState(f.l, cs, false, noise.HandshakeIX)
crt := cs.GetDefaultCertificate()
if crt == nil {
f.l.WithField("udpAddr", addr).
WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).
WithField("certVersion", cs.defaultVersion).
Error("Unable to handshake with host because no certificate is available")
}

ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX)
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to create connection state")
return
}

// Mark packet 1 as seen so it doesn't show up as missed
ci.window.Update(f.l, 1)

msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
if err != nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed to call noise.ReadMessage")
return
}

hs := &NebulaHandshake{}
err = hs.Unmarshal(msg)
if err != nil || hs.Details == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
Error("Failed unmarshal handshake message")
return
}

Expand All @@ -98,7 +148,6 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
if remoteCert.Certificate.Version() != ci.myCert.Version() {
// We started off using the wrong certificate version, let's see if we can match the version that was sent to us
rc := cs.getCertificate(remoteCert.Certificate.Version())
//TODO: anywhere we are logging remoteCert needs to be remoteCert.Certificate OR we make a pass through func on CachedCertificate
if rc == nil {
f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert).
Expand Down Expand Up @@ -183,6 +232,18 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet

hs.Details.ResponderIndex = myIndex
hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version())
if hs.Details.Cert == nil {
f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr).
WithField("certName", certName).
WithField("fingerprint", fingerprint).
WithField("issuer", issuer).
WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex).
WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
WithField("certVersion", ci.myCert.Version()).
Error("Unable to handshake with host because no certificate handshake bytes is available")
return
}

hs.Details.CertVersion = uint32(ci.myCert.Version())
// Update the time in case their clock is way off from ours
hs.Details.Time = uint64(time.Now().UnixNano())
Expand Down
13 changes: 12 additions & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gaissmai/bart"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
Expand Down Expand Up @@ -327,7 +328,17 @@ func (f *Interface) reloadFirewall(c *config.C) {
return
}

fw, err := NewFirewallFromConfig(f.l, f.pki.getDefaultCertificate(), c)
cs := f.pki.getCertState()
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}

if certificate == nil {
panic("No certificate available to reconfigure the firewall")
}

fw, err := NewFirewallFromConfig(f.l, certificate, c)
if err != nil {
f.l.WithError(err).Error("Error while creating firewall during reload")
return
Expand Down
79 changes: 55 additions & 24 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,42 +738,73 @@ func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) {
return
}

// Send a query to the lighthouses and hope for the best next time
v := lh.ifce.GetCertState().defaultVersion
msg := &NebulaMeta{
Type: NebulaMeta_HostQuery,
Details: &NebulaMetaDetails{},
}

if v == 1 {
if !addr.Is4() {
lh.l.WithField("vpnAddr", addr).Error("Can't query lighthouse for v6 address using a v1 protocol")
return
var v1Query, v2Query []byte
var err error
var v cert.Version
queried := 0
lighthouses := lh.GetLighthouses()

for lhVpnAddr := range lighthouses {
hi := lh.ifce.GetHostInfo(lhVpnAddr)
if hi != nil {
v = hi.ConnectionState.myCert.Version()
} else {
v = lh.ifce.GetCertState().defaultVersion
}
b := addr.As4()
msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])

} else if v == 2 {
msg.Details.VpnAddr = netAddrToProtoAddr(addr)
if v == cert.Version1 {
if !addr.Is4() {
lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr).
Error("Can't query lighthouse for v6 address using a v1 protocol")
continue
}

} else {
panic("unsupported version")
}
if v1Query == nil {
b := addr.As4()
msg.Details.VpnAddr = nil
msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:])

query, err := msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("vpnAddr", addr).Error("Failed to marshal lighthouse query payload")
return
}
v1Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v1 query payload")
continue
}
}

lighthouses := lh.GetLighthouses()
lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses)))
lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out)
queried++

for n := range lighthouses {
//TODO: there is a slight possibility this lighthouse is using a v2 protocol even if our default is v1
// We could facilitate the move to v2 by marshalling a v2 query
lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out)
} else if v == cert.Version2 {
if v2Query == nil {
msg.Details.OldVpnAddr = 0
msg.Details.VpnAddr = netAddrToProtoAddr(addr)

v2Query, err = msg.Marshal()
if err != nil {
lh.l.WithError(err).WithField("queryVpnAddr", addr).
WithField("lighthouseAddr", lhVpnAddr).
Error("Failed to marshal lighthouse v2 query payload")
continue
}
}

lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out)
queried++

} else {
lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v)
continue
}
}

lh.metricTx(NebulaMeta_HostQuery, int64(queried))
}

func (lh *LightHouse) StartUpdateWorker() {
Expand Down
12 changes: 11 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/netip"
"time"

"github.com/slackhq/nebula/cert"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/overlay"
Expand Down Expand Up @@ -60,7 +62,15 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err)
}

certificate := pki.getDefaultCertificate()
cs := pki.getCertState()
certificate := cs.getCertificate(cert.Version2)
if certificate == nil {
certificate = cs.getCertificate(cert.Version1)
}

if certificate == nil {
panic("No certificates available to configure the firewall")
}
fw, err := NewFirewallFromConfig(l, certificate, c)
if err != nil {
return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err)
Expand Down
Loading
Loading