From 14875480f3ae6e713a9f3b02339d0bbf16d93ecb Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Wed, 13 Nov 2024 17:03:35 -0600 Subject: [PATCH] PSK support for v2 --- connection_state.go | 15 ++-- e2e/handshakes_test.go | 132 +++++++++++++++++++++++++++++++++ e2e/router/router.go | 7 +- examples/config.yml | 33 ++++++++- handshake_ix.go | 59 ++++++++++----- handshake_manager_test.go | 5 ++ pki.go | 12 +++ psk.go | 150 ++++++++++++++++++++++++++++++++++++++ psk_test.go | 71 ++++++++++++++++++ 9 files changed, 451 insertions(+), 33 deletions(-) create mode 100644 psk.go create mode 100644 psk_test.go diff --git a/connection_state.go b/connection_state.go index faee443de..eb7342c72 100644 --- a/connection_state.go +++ b/connection_state.go @@ -27,7 +27,7 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { +func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern, psk []byte) (*ConnectionState, error) { var dhFunc noise.DHFunc switch crt.Curve() { case cert.Curve_CURVE25519: @@ -56,13 +56,12 @@ func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, i b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: ncs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - //NOTE: These should come from CertState (pki.go) when we finally implement it - PresharedKey: []byte{}, + CipherSuite: ncs, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: static, + PresharedKey: psk, PresharedKeyPlacement: 0, }) if err != nil { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 29b9d536e..09811f0f5 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -1105,6 +1105,138 @@ func TestV2NonPrimaryWithLighthouse(t *testing.T) { theirControl.Stop() } +func TestPSK(t *testing.T) { + tests := []struct { + name string + myPskMode nebula.PskMode + theirPskMode nebula.PskMode + }{ + // All accepting + { + name: "both accepting", + myPskMode: nebula.PskAccepting, + theirPskMode: nebula.PskAccepting, + }, + + // accepting and sending both ways + { + name: "accepting to sending", + myPskMode: nebula.PskAccepting, + theirPskMode: nebula.PskSending, + }, + { + name: "sending to accepting", + myPskMode: nebula.PskSending, + theirPskMode: nebula.PskAccepting, + }, + + // All sending + { + name: "sending to sending", + myPskMode: nebula.PskSending, + theirPskMode: nebula.PskSending, + }, + + // enforced and sending both ways + { + name: "enforced to sending", + myPskMode: nebula.PskEnforced, + theirPskMode: nebula.PskSending, + }, + { + name: "sending to enforced", + myPskMode: nebula.PskSending, + theirPskMode: nebula.PskEnforced, + }, + + // All enforced + { + name: "both enforced", + myPskMode: nebula.PskEnforced, + theirPskMode: nebula.PskEnforced, + }, + + // Enforced can technically handshake with an accepting node, but it is bad to be in this state + { + name: "enforced to accepting", + myPskMode: nebula.PskEnforced, + theirPskMode: nebula.PskAccepting, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var myPskSettings, theirPskSettings m + + switch test.myPskMode { + case nebula.PskAccepting: + myPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage0", "this is a key"}}} + case nebula.PskSending: + myPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage1"}}} + case nebula.PskEnforced: + myPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage2"}}} + } + + switch test.theirPskMode { + case nebula.PskAccepting: + theirPskSettings = m{"psk": &m{"mode": "accepting", "keys": []string{"garbage3", "this is a key"}}} + case nebula.PskSending: + theirPskSettings = m{"psk": &m{"mode": "sending", "keys": []string{"this is a key", "garbage4"}}} + case nebula.PskEnforced: + theirPskSettings = m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key", "garbage5"}}} + } + + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + myControl, myVpnIp, myUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "me", "10.0.0.1/24", myPskSettings) + theirControl, theirVpnIp, theirUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.0.0.2/24", theirPskSettings) + + myControl.InjectLightHouseAddr(theirVpnIp[0].Addr(), theirUdpAddr) + r := router.NewR(t, myControl, theirControl) + + // Start the servers + myControl.Start() + theirControl.Start() + + t.Log("Route until we see our cached packet flow") + myControl.InjectTunUDPPacket(theirVpnIp[0].Addr(), 80, myVpnIp[0].Addr(), 80, []byte("Hi from me")) + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + h := &header.H{} + err := h.Parse(p.Data) + if err != nil { + panic(err) + } + + // If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should + // not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now + if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskSending { + if h.Type == 0 && h.MessageCounter == 1 { + assert.NotContains(t, string(p.Data), "test me") + } + } + + if p.To == theirUdpAddr && h.Type == 1 { + return router.RouteAndExit + } + + return router.KeepRouting + }) + + t.Log("My cached packet should be received by them") + myCachedPacket := theirControl.GetFromTun(true) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), 80, 80) + + t.Log("Test the tunnel with them") + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl) + assertTunnel(t, myVpnIp[0].Addr(), theirVpnIp[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() + //TODO: assert hostmaps + }) + } + +} + //TODO: test // Race winner renews and handshakes // Race loser renews and handshakes diff --git a/e2e/router/router.go b/e2e/router/router.go index 5e52ed77c..8c166fe30 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -111,10 +111,6 @@ type ExitFunc func(packet *udp.Packet, receiver *nebula.Control) ExitType func NewR(t testing.TB, controls ...*nebula.Control) *R { ctx, cancel := context.WithCancel(context.Background()) - if err := os.MkdirAll("mermaid", 0755); err != nil { - panic(err) - } - r := &R{ controls: make(map[netip.AddrPort]*nebula.Control), vpnControls: make(map[netip.Addr]*nebula.Control), @@ -194,6 +190,9 @@ func (r *R) renderFlow() { return } + if err := os.MkdirAll(filepath.Dir(r.fn), 0755); err != nil { + panic(err) + } f, err := os.OpenFile(r.fn, os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0644) if err != nil { panic(err) diff --git a/examples/config.yml b/examples/config.yml index 1a312838b..679c4fc1b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -19,6 +19,38 @@ pki: # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. # default_version: 1 + # psk can be used to mask the contents of handshakes. + psk: + # `mode` defines how the pre shared keys can be used in a handshake. + # `accepting` (the default) will initiate handshakes using an empty key and will try to use any keys provided when + # receiving handshakes, including an empty key. + # `sending` will initiate handshakes with the first key provided and will try to use any keys provided when + # receiving handshakes, including an empty key. + # `enforced` will initiate handshakes with the first psk key provided and will try to use any keys provided when + # responding to handshakes. An empty key will not be allowed. + # + # To change a mesh from not using a psk to enforcing psk: + # 1. Leave `mode` as `accepting` and configure `psk.keys` to match on all nodes in the mesh and reload. + # 2. Change `mode` to `sending` on all nodes in the mesh and reload. + # 3. Change `mode` to `enforced` on all nodes in the mesh and reload. + #mode: accepting + + # The keys provided are sent through hkdf to ensure the shared secret used in the noise protocol is the + # correct byte length. + # + # Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on + # incoming handshakes. This is to allow for psk rotation. + # + # To rotate a primary key: + # 1. Put the new key in the 2nd slot on every node in the mesh and reload. + # 2. Move the key from the 2nd slot to the 1st slot, the old primary key is now in the 2nd slot, reload. + # 3. Remove the old primary key once it is no longer in use on every node in the mesh and reload. + #keys: + # - shared secret string, this one is used in all outbound handshakes # This is the primary key used when sending handshakes + # - this is a fallback key, received handshakes can use this + # - another fallback, received handshakes can use this one too + # - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire + # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # The syntax is: @@ -309,7 +341,6 @@ logging: # after receiving the response for lighthouse queries #trigger_buffer: 64 - # Nebula security group configuration firewall: # Action to take when a packet is not allowed by the firewall rules. diff --git a/handshake_ix.go b/handshake_ix.go index d1f3b5a9f..b06e4026d 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -50,7 +50,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { Error("Unable to handshake with host because no certificate handshake bytes is available") } - ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX, cs.psk.primary) if err != nil { f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). @@ -104,34 +104,53 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Error("Unable to handshake with host because no certificate is available") } - ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) - if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to create connection state") - return - } + var ( + err error + ci *ConnectionState + msg []byte + ) - // Mark packet 1 as seen so it doesn't show up as missed - ci.window.Update(f.l, 1) + hs := &NebulaHandshake{} - msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) - if err != nil { - f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed to call noise.ReadMessage") - return + for _, psk := range cs.psk.keys { + ci, err = NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX, psk) + if err != nil { + //TODO: should be bother logging this, if we have multiple psks and the error is unrelated it will be verbose. + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to create connection state") + continue + } + + msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:]) + if err != nil { + // Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one + continue + } + + // Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf + // comes out clean as well + err = hs.Unmarshal(msg) + if err == nil { + // There was no error, we can continue with this handshake + break + } + + // The unmarshal failed, try the next psk if we have one } - hs := &NebulaHandshake{} - err = hs.Unmarshal(msg) + // We finished with an error, log it and get out if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("udpAddr", addr). + // We aren't logging the error here because we can't be sure of the failure when using psk + f.l.WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - Error("Failed unmarshal handshake message") + Error("Was unable to decrypt the handshake") return } + // Mark packet 1 as seen so it doesn't show up as missed + ci.window.Update(f.l, 1) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("udpAddr", addr). diff --git a/handshake_manager_test.go b/handshake_manager_test.go index 7edc55b9c..559eac5a9 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -10,6 +10,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_NewHandshakeManagerVpnIp(t *testing.T) { @@ -23,11 +24,15 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { lh := newTestLighthouse() + psk, err := NewPsk(PskAccepting, nil) + require.NoError(t, err) + cs := &CertState{ defaultVersion: cert.Version1, privateKey: []byte{}, v1Cert: &dummyCert{version: cert.Version1}, v1HandshakeBytes: []byte{}, + psk: psk, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) diff --git a/pki.go b/pki.go index 69c60863b..1282a4f8e 100644 --- a/pki.go +++ b/pki.go @@ -38,6 +38,8 @@ type CertState struct { pkcs11Backed bool cipher string + psk *Psk + myVpnNetworks []netip.Prefix myVpnNetworksTable *bart.Table[struct{}] myVpnAddrs []netip.Addr @@ -181,6 +183,16 @@ func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { } } + psk, err := NewPskFromConfig(c) + if err != nil { + return util.NewContextualError("Failed to load psk from config", nil, err) + } + if len(psk.keys) > 0 { + p.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.keys)). + Info("pre shared keys are in use") + } + newState.psk = psk + p.cs.Store(newState) //TODO: newState needs a stringer that does json diff --git a/psk.go b/psk.go new file mode 100644 index 000000000..987f5929d --- /dev/null +++ b/psk.go @@ -0,0 +1,150 @@ +package nebula + +import ( + "crypto/sha256" + "errors" + "fmt" + "io" + + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" + "golang.org/x/crypto/hkdf" +) + +var ErrNotAPskMode = errors.New("not a psk mode") +var ErrKeyTooShort = errors.New("key is too short") +var ErrNotEnoughPskKeys = errors.New("at least 1 key is required") + +// MinPskLength is the minimum bytes that we accept for a user defined psk, the choice is arbitrary +const MinPskLength = 8 + +type PskMode int + +const ( + PskAccepting PskMode = 0 + PskSending PskMode = 1 + PskEnforced PskMode = 2 +) + +func NewPskMode(m string) (PskMode, error) { + switch m { + case "accepting": + return PskAccepting, nil + case "sending": + return PskSending, nil + case "enforced": + return PskEnforced, nil + } + return PskAccepting, ErrNotAPskMode +} + +func (p PskMode) String() string { + switch p { + case PskAccepting: + return "accepting" + case PskSending: + return "sending" + case PskEnforced: + return "enforced" + } + + return "unknown" +} + +func (p PskMode) IsValid() bool { + switch p { + case PskAccepting, PskSending, PskEnforced: + return true + default: + return false + } +} + +type Psk struct { + // pskMode sets how psk works, ignored, allowed for incoming, or enforced for all + mode PskMode + + // primary is the key to use when sending, it may be nil + primary []byte + + // keys holds all pre-computed psk hkdfs + // Handshakes iterate this directly + keys [][]byte +} + +// NewPskFromConfig is a helper for initial boot and config reloading. +func NewPskFromConfig(c *config.C) (*Psk, error) { + sMode := c.GetString("psk.mode", "accepting") + mode, err := NewPskMode(sMode) + if err != nil { + return nil, util.NewContextualError("Could not parse psk.mode", m{"mode": mode}, err) + } + + return NewPsk( + mode, + c.GetStringSlice("psk.keys", nil), + ) +} + +// NewPsk creates a new Psk object and handles the caching of all accepted keys +func NewPsk(mode PskMode, keys []string) (*Psk, error) { + if !mode.IsValid() { + return nil, ErrNotAPskMode + } + + psk := &Psk{ + mode: mode, + } + + err := psk.cachePsks(keys) + if err != nil { + return nil, err + } + + return psk, nil +} + +// cachePsks generates all psks we accept and caches them to speed up handshaking +func (p *Psk) cachePsks(keys []string) error { + if p.mode != PskAccepting && len(keys) < 1 { + return ErrNotEnoughPskKeys + } + + p.keys = [][]byte{} + + for i, rk := range keys { + k, err := sha256KdfFromString(rk) + if err != nil { + return fmt.Errorf("failed to generate key for position %v: %w", i, err) + } + + p.keys = append(p.keys, k) + } + + if p.mode != PskAccepting { + // We are either sending or enforcing, the primary key must the first slot + p.primary = p.keys[0] + } + + if p.mode != PskEnforced { + // If we are not enforcing psk use then a nil psk is allowed + p.keys = append(p.keys, nil) + } + + return nil +} + +// sha256KdfFromString generates a useful key to use from a provided secret +func sha256KdfFromString(secret string) ([]byte, error) { + if len(secret) < MinPskLength { + return nil, ErrKeyTooShort + } + + hmacKey := make([]byte, sha256.Size) + _, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, nil), hmacKey) + if err != nil { + return nil, err + } + + return hmacKey, nil +} diff --git a/psk_test.go b/psk_test.go new file mode 100644 index 000000000..924079549 --- /dev/null +++ b/psk_test.go @@ -0,0 +1,71 @@ +package nebula + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewPsk(t *testing.T) { + t.Run("mode accepting", func(t *testing.T) { + p, err := NewPsk(PskAccepting, nil) + assert.NoError(t, err) + assert.Equal(t, PskAccepting, p.mode) + assert.Nil(t, p.keys[0]) + assert.Nil(t, p.primary) + + p, err = NewPsk(PskAccepting, []string{"1234567"}) + assert.Error(t, ErrKeyTooShort) + + p, err = NewPsk(PskAccepting, []string{"hi there friends"}) + assert.NoError(t, err) + assert.Equal(t, PskAccepting, p.mode) + assert.Nil(t, p.primary) + assert.Len(t, p.keys, 2) + assert.Nil(t, p.keys[1]) + + expectedCache := []byte{ + 0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65, + 0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef, + } + assert.Equal(t, expectedCache, p.keys[0]) + }) + + t.Run("mode sending", func(t *testing.T) { + p, err := NewPsk(PskSending, nil) + assert.Error(t, ErrNotEnoughPskKeys, err) + + p, err = NewPsk(PskSending, []string{"1234567"}) + assert.Error(t, ErrKeyTooShort) + + p, err = NewPsk(PskSending, []string{"hi there friends"}) + assert.NoError(t, err) + assert.Equal(t, PskSending, p.mode) + assert.Len(t, p.keys, 2) + assert.Nil(t, p.keys[1]) + + expectedCache := []byte{ + 0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65, + 0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef, + } + assert.Equal(t, expectedCache, p.keys[0]) + assert.Equal(t, p.keys[0], p.primary) + }) + + t.Run("mode enforced", func(t *testing.T) { + p, err := NewPsk(PskEnforced, nil) + assert.Error(t, ErrNotEnoughPskKeys, err) + + p, err = NewPsk(PskEnforced, []string{"hi there friends"}) + assert.NoError(t, err) + assert.Equal(t, PskEnforced, p.mode) + assert.Len(t, p.keys, 1) + + expectedCache := []byte{ + 0xb9, 0x8c, 0xdc, 0xac, 0x77, 0xf4, 0x8c, 0xf8, 0x1d, 0xe7, 0xe7, 0xb, 0x53, 0x25, 0xd3, 0x65, + 0xa3, 0x9f, 0x78, 0xb2, 0xc7, 0x2d, 0xa5, 0xd8, 0x84, 0x81, 0x7b, 0xb5, 0xdb, 0xe0, 0x9a, 0xef, + } + assert.Equal(t, expectedCache, p.keys[0]) + assert.Equal(t, p.keys[0], p.primary) + }) +}