Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cert v2 + tun changes for Linux #1224

Merged
merged 17 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ type CachedCertificate struct {
func UnmarshalCertificate(b []byte) (Certificate, error) {
//TODO: you left off here, no one uses this function but it might be beneficial to export _something_ that someone can use, maybe the Versioned unmarshallsers?
var c Certificate
c, err := unmarshalCertificateV2(b, nil)
c, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
if err == nil {
return c, nil
}
Expand All @@ -129,15 +129,15 @@ func UnmarshalCertificate(b []byte) (Certificate, error) {
// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake.
// Handshakes save space by placing the peers public key in a different part of the packet, we have to
// reassemble the actual certificate structure with that in mind.
func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (Certificate, error) {
func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) {
var c Certificate
var err error

switch v {
case VersionPre1, Version1:
c, err = unmarshalCertificateV1(b, publicKey)
case Version2:
c, err = unmarshalCertificateV2(b, publicKey)
c, err = unmarshalCertificateV2(b, publicKey, curve)
default:
//TODO: make a static var
return nil, fmt.Errorf("unknown certificate version %d", v)
Expand All @@ -146,10 +146,15 @@ func UnmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte) (C
if err != nil {
return nil, err
}

if c.Curve() != curve {
return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String())
}

return c, nil
}

func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAPool) (*CachedCertificate, error) {
func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) {
if publicKey == nil {
return nil, ErrNoPeerStaticKey
}
Expand All @@ -158,7 +163,7 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, caPool *CAP
return nil, ErrNoPayload
}

c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey)
c, err := UnmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve)
if err != nil {
return nil, fmt.Errorf("error unmarshaling cert: %w", err)
}
Expand Down
4 changes: 1 addition & 3 deletions cert/cert_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"fmt"
"net"
"net/netip"
"slices"
"time"

"golang.org/x/crypto/curve25519"
Expand Down Expand Up @@ -393,8 +392,7 @@ func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error)
}
}

slices.SortFunc(nc.details.networks, comparePrefix)
slices.SortFunc(nc.details.unsafeNetworks, comparePrefix)
//do not sort the subnets field for V1 certs

return &nc, nil
}
Expand Down
7 changes: 4 additions & 3 deletions cert/cert_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ func (d *detailsV2) Marshal() ([]byte, error) {
return b.Bytes()
}

func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error) {
func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) {
l := len(b)
if l == 0 || l > MaxCertificateSize {
return nil, ErrBadFormat
Expand All @@ -473,11 +473,12 @@ func unmarshalCertificateV2(b []byte, publicKey []byte) (*certificateV2, error)
return nil, ErrBadFormat
}

//Maybe grab the curve
var rawCurve byte
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(Curve_CURVE25519)) {
if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) {
return nil, ErrBadFormat
}
curve := Curve(rawCurve)
curve = Curve(rawCurve)

// Maybe grab the public key
var rawPublicKey cryptobyte.String
Expand Down
2 changes: 1 addition & 1 deletion cert/pem.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) {
case CertificateBanner:
c, err = unmarshalCertificateV1(p.Bytes, nil)
case CertificateV2Banner:
c, err = unmarshalCertificateV2(p.Bytes, nil)
c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519)
default:
return nil, r, ErrInvalidPEMCertificateBanner
}
Expand Down
4 changes: 4 additions & 0 deletions connection_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
"message_counter": cs.messageCounter.Load(),
})
}

func (cs *ConnectionState) Curve() cert.Curve {
return cs.myCert.Curve()
}
9 changes: 9 additions & 0 deletions dns_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeAAAA:
l.Debugf("Query for AAAA %s", q.Name)
ip := dnsR.Query(q.Name)
if ip != "" {
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
case dns.TypeTXT:
a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
b, err := netip.ParseAddr(a)
Expand Down
4 changes: 2 additions & 2 deletions handshake_ix.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet
return
}

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 1, "style": "ix_psk0"})
Expand Down Expand Up @@ -404,7 +404,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha
return true
}

remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), f.pki.GetCAPool())
remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool())
if err != nil {
e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr).
WithField("handshake", m{"stage": 2, "style": "ix_psk0"})
Expand Down
3 changes: 3 additions & 0 deletions lighthouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,9 @@ func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnAdd
useVersion = 2
}

//todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They should default to protocol v2, are you seeing something I'm missing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

protocol v2 is working, this is actually a multi-IP issue in a funny hat.

If you have a LH at 10.0.0.1 and you contact it via 10.0.0.2 (who is also fc00:02), because we try to use the IP in the message instead of the hostmap, the LH will never learn the underlay IP for fc00:02, unless you also contact it via overlay-ipv6.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless of course, I have something backwards

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah gotcha. Yeah this and the hostmap don't understand follow on addresses yet. We need a final loop on the updates to point all subsequent addresses to the primary address. I have some of this code staged but its not ready just yet.

//todo why do we care about the vpnip in the packet? We know where it came from, right?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its a great question and I have it in my internals notes as well as ditching the V4AddrPort in v2 protocol stuff.


if detailsVpnIp != vpnAddrs[0] {
if lhh.l.Level >= logrus.DebugLevel {
lhh.l.WithField("vpnAddrs", vpnAddrs).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update")
Expand Down
111 changes: 101 additions & 10 deletions outside.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"net/netip"
"time"

"github.com/google/gopacket/layers"
"golang.org/x/net/ipv6"

"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/firewall"
"github.com/slackhq/nebula/header"
Expand Down Expand Up @@ -297,22 +300,112 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h

// newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
if len(data) < 1 {
return errors.New("packet too short")
}

version := int((data[0] >> 4) & 0x0f)
switch version {
case ipv4.Version:
return parseV4(data, incoming, fp)
case ipv6.Version:
return parseV6(data, incoming, fp)
}
return fmt.Errorf("packet is an unknown ip version: %v", version)
}

func parseV6(data []byte, incoming bool, fp *firewall.Packet) error {
dataLen := len(data)
if dataLen < ipv6.HeaderLen {
return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen)
}

if incoming {
fp.RemoteIP, _ = netip.AddrFromSlice(data[8:24])
fp.LocalIP, _ = netip.AddrFromSlice(data[24:40])
} else {
fp.LocalIP, _ = netip.AddrFromSlice(data[8:24])
fp.RemoteIP, _ = netip.AddrFromSlice(data[24:40])
}

//TODO: whats a reasonable number of extension headers to attempt to parse?
//https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html
protoAt := 6
offset := 40
for i := 0; i < 24; i++ {
if dataLen < offset {
break
}

proto := layers.IPProtocol(data[protoAt])
//fmt.Println(proto, protoAt)
switch proto {
case layers.IPProtocolICMPv6:
//TODO: we need a new protocol in config language "icmpv6"
fp.Protocol = uint8(proto)
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = false
return nil

// Is it an ipv4 packet?
if int((data[0]>>4)&0x0f) != 4 {
return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
case layers.IPProtocolTCP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
fp.Fragment = false
return nil

case layers.IPProtocolUDP:
if dataLen < offset+4 {
return fmt.Errorf("ipv6 packet was too small")
}
fp.Protocol = uint8(proto)
fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2])
fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4])
fp.Fragment = false
return nil

case layers.IPProtocolIPv6Fragment:
//TODO: can we determine the protocol?
fp.RemotePort = 0
fp.LocalPort = 0
fp.Fragment = true
return nil

default:
if dataLen < offset+1 {
break
}

next := int(data[offset+1]) * 8
if next == 0 {
// each extension is at least 8 bytes
next = 8
}

protoAt = offset
offset = offset + next
}
}

return fmt.Errorf("could not find payload in ipv6 packet")
}

func parseV4(data []byte, incoming bool, fp *firewall.Packet) error {
// Do we at least have an ipv4 header worth of data?
if len(data) < ipv4.HeaderLen {
return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen)
}

// Adjust our start position based on the advertised ip header length
ihl := int(data[0]&0x0f) << 2

// Well formed ip header length?
if ihl < ipv4.HeaderLen {
return fmt.Errorf("packet had an invalid header length: %v", ihl)
return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl)
}

// Check if this is the second or further fragment of a fragmented packet.
Expand All @@ -328,12 +421,11 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
minLen += minFwPacketLen
}
if len(data) < minLen {
return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl)
}

// Firewall packets are locally oriented
if incoming {
//TODO: IPV6-WORK
fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16])
fp.LocalIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
Expand All @@ -344,7 +436,6 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
}
} else {
//TODO: IPV6-WORK
fp.LocalIP, _ = netip.AddrFromSlice(data[12:16])
fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20])
if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
Expand Down
18 changes: 12 additions & 6 deletions outside_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,15 @@ import (
func Test_newPacket(t *testing.T) {
p := &firewall.Packet{}

// length fail
err := newPacket([]byte{0, 1}, true, p)
assert.EqualError(t, err, "packet is less than 20 bytes")
// length fails
err := newPacket([]byte{}, true, p)
assert.EqualError(t, err, "packet too short")

err = newPacket([]byte{0x40}, true, p)
assert.EqualError(t, err, "ipv4 packet is less than 20 bytes")

err = newPacket([]byte{0x60}, true, p)
assert.EqualError(t, err, "ipv6 packet is less than 20 bytes")

// length fail with ip options
h := ipv4.Header{
Expand All @@ -29,15 +35,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal()
err = newPacket(b, true, p)

assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24")
assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24")

// not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet is not ipv4, type: 0")
assert.EqualError(t, err, "packet is an unknown ip version: 0")

// invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.EqualError(t, err, "packet had an invalid header length: 8")
assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8")

// account for variable ip header length - incoming
h = ipv4.Header{
Expand Down
Loading
Loading