diff --git a/.gitignore b/.gitignore index 0bffc857c..55068f31b 100644 --- a/.gitignore +++ b/.gitignore @@ -5,7 +5,8 @@ /nebula-darwin /nebula.exe /nebula-cert.exe -/coverage.out +**/coverage.out +**/cover.out /cpu.pprof /build /*.tar.gz diff --git a/Makefile b/Makefile index 6922cc3e8..d3fbcaa57 100644 --- a/Makefile +++ b/Makefile @@ -196,7 +196,7 @@ bench-cpu-long: go test -bench=. -benchtime=60s -cpuprofile=cpu.pprof go tool pprof go-audit.test cpu.pprof -proto: nebula.pb.go cert/cert.pb.go +proto: nebula.pb.go cert/cert_v1.pb.go nebula.pb.go: nebula.proto .FORCE go build github.com/gogo/protobuf/protoc-gen-gogofaster diff --git a/calculated_remote.go b/calculated_remote.go index ae2ed500c..32d062a83 100644 --- a/calculated_remote.go +++ b/calculated_remote.go @@ -21,7 +21,11 @@ type calculatedRemote struct { port uint32 } -func newCalculatedRemote(maskCidr netip.Prefix, port int) (*calculatedRemote, error) { +func newCalculatedRemote(cidr, maskCidr netip.Prefix, port int) (*calculatedRemote, error) { + if maskCidr.Addr().BitLen() != cidr.Addr().BitLen() { + return nil, fmt.Errorf("invalid mask: %s for cidr: %s", maskCidr, cidr) + } + masked := maskCidr.Masked() if port < 0 || port > math.MaxUint16 { return nil, fmt.Errorf("invalid port: %d", port) @@ -38,32 +42,38 @@ func (c *calculatedRemote) String() string { return fmt.Sprintf("CalculatedRemote(mask=%v port=%d)", c.ipNet, c.port) } -func (c *calculatedRemote) Apply(ip netip.Addr) *Ip4AndPort { - // Combine the masked bytes of the "mask" IP with the unmasked bytes - // of the overlay IP - if c.ipNet.Addr().Is4() { - return c.apply4(ip) - } - return c.apply6(ip) -} - -func (c *calculatedRemote) apply4(ip netip.Addr) *Ip4AndPort { - //TODO: IPV6-WORK this can be less crappy +func (c *calculatedRemote) ApplyV4(addr netip.Addr) *V4AddrPort { + // Combine the masked bytes of the "mask" IP with the unmasked bytes of the overlay IP maskb := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) mask := binary.BigEndian.Uint32(maskb[:]) b := c.mask.Addr().As4() - maskIp := binary.BigEndian.Uint32(b[:]) + maskAddr := binary.BigEndian.Uint32(b[:]) - b = ip.As4() - intIp := binary.BigEndian.Uint32(b[:]) + b = addr.As4() + intAddr := binary.BigEndian.Uint32(b[:]) - return &Ip4AndPort{(maskIp & mask) | (intIp & ^mask), c.port} + return &V4AddrPort{(maskAddr & mask) | (intAddr & ^mask), c.port} } -func (c *calculatedRemote) apply6(ip netip.Addr) *Ip4AndPort { - //TODO: IPV6-WORK - panic("Can not calculate ipv6 remote addresses") +func (c *calculatedRemote) ApplyV6(addr netip.Addr) *V6AddrPort { + mask := net.CIDRMask(c.mask.Bits(), c.mask.Addr().BitLen()) + maskAddr := c.mask.Addr().As16() + calcAddr := addr.As16() + + ap := V6AddrPort{Port: c.port} + + maskb := binary.BigEndian.Uint64(mask[:8]) + maskAddrb := binary.BigEndian.Uint64(maskAddr[:8]) + calcAddrb := binary.BigEndian.Uint64(calcAddr[:8]) + ap.Hi = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + maskb = binary.BigEndian.Uint64(mask[8:]) + maskAddrb = binary.BigEndian.Uint64(maskAddr[8:]) + calcAddrb = binary.BigEndian.Uint64(calcAddr[8:]) + ap.Lo = (maskAddrb & maskb) | (calcAddrb & ^maskb) + + return &ap } func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calculatedRemote], error) { @@ -89,8 +99,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu return nil, fmt.Errorf("config `%s` has invalid CIDR: %s", k, rawCIDR) } - //TODO: IPV6-WORK this does not verify that rawValue contains the same bits as cidr here - entry, err := newCalculatedRemotesListFromConfig(rawValue) + entry, err := newCalculatedRemotesListFromConfig(cidr, rawValue) if err != nil { return nil, fmt.Errorf("config '%s.%s': %w", k, rawCIDR, err) } @@ -101,7 +110,7 @@ func NewCalculatedRemotesFromConfig(c *config.C, k string) (*bart.Table[[]*calcu return calculatedRemotes, nil } -func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { +func newCalculatedRemotesListFromConfig(cidr netip.Prefix, raw any) ([]*calculatedRemote, error) { rawList, ok := raw.([]any) if !ok { return nil, fmt.Errorf("calculated_remotes entry has invalid type: %T", raw) @@ -109,7 +118,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { var l []*calculatedRemote for _, e := range rawList { - c, err := newCalculatedRemotesEntryFromConfig(e) + c, err := newCalculatedRemotesEntryFromConfig(cidr, e) if err != nil { return nil, fmt.Errorf("calculated_remotes entry: %w", err) } @@ -119,7 +128,7 @@ func newCalculatedRemotesListFromConfig(raw any) ([]*calculatedRemote, error) { return l, nil } -func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { +func newCalculatedRemotesEntryFromConfig(cidr netip.Prefix, raw any) (*calculatedRemote, error) { rawMap, ok := raw.(map[any]any) if !ok { return nil, fmt.Errorf("invalid type: %T", raw) @@ -155,5 +164,5 @@ func newCalculatedRemotesEntryFromConfig(raw any) (*calculatedRemote, error) { return nil, fmt.Errorf("invalid port (type %T): %v", rawValue, rawValue) } - return newCalculatedRemote(maskCidr, port) + return newCalculatedRemote(cidr, maskCidr, port) } diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 6ff1cb0bd..066213eaf 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -9,10 +9,9 @@ import ( ) func TestCalculatedRemoteApply(t *testing.T) { - ipNet, err := netip.ParsePrefix("192.168.1.0/24") - require.NoError(t, err) - - c, err := newCalculatedRemote(ipNet, 4242) + // Test v4 addresses + ipNet := netip.MustParsePrefix("192.168.1.0/24") + c, err := newCalculatedRemote(ipNet, ipNet, 4242) require.NoError(t, err) input, err := netip.ParseAddr("10.0.10.182") @@ -21,5 +20,62 @@ func TestCalculatedRemoteApply(t *testing.T) { expected, err := netip.ParseAddr("192.168.1.182") assert.NoError(t, err) - assert.Equal(t, NewIp4AndPortFromNetIP(expected, 4242), c.Apply(input)) + assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) + + // Test v6 addresses + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff::0/64") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff:ffff:ffff::0/80") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) + + // Test v6 addresses part 2 + ipNet = netip.MustParsePrefix("ffff:ffff:ffff::0/48") + c, err = newCalculatedRemote(ipNet, ipNet, 4242) + require.NoError(t, err) + + input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") + assert.NoError(t, err) + + assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) +} + +func Test_newCalculatedRemote(t *testing.T) { + c, err := newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.EqualError(t, err, "invalid mask: 1.0.0.0/32 for cidr: 1::1/128") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1::1/128"), 4242) + require.EqualError(t, err, "invalid mask: 1::1/128 for cidr: 1.0.0.0/32") + require.Nil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1.0.0.0/32"), netip.MustParsePrefix("1.0.0.0/32"), 4242) + require.NoError(t, err) + require.NotNil(t, c) + + c, err = newCalculatedRemote(netip.MustParsePrefix("1::1/128"), netip.MustParsePrefix("1::1/128"), 4242) + require.NoError(t, err) + require.NotNil(t, c) } diff --git a/cert/README.md b/cert/README.md index ae19a28f5..1e27a6bf5 100644 --- a/cert/README.md +++ b/cert/README.md @@ -2,14 +2,25 @@ This is a library for interacting with `nebula` style certificates and authorities. -A `protobuf` definition of the certificate format is also included +There are now 2 versions of `nebula` certificates: -### Compiling the protobuf definition +## v1 -Make sure you have `protoc` installed. +This version is deprecated. + +A `protobuf` definition of the certificate format is included at `cert_v1.proto` + +To compile the definition you will need `protoc` installed. To compile for `go` with the same version of protobuf specified in go.mod: ```bash -make +make proto ``` + +## v2 + +This is the latest version which uses asn.1 DER encoding. It can support ipv4 and ipv6 and tolerate +future certificate changes better than v1. + +`cert_v2.asn1` defines the wire format and can be used to compile marshalers. \ No newline at end of file diff --git a/cert/asn1.go b/cert/asn1.go new file mode 100644 index 000000000..6bf6a8ded --- /dev/null +++ b/cert/asn1.go @@ -0,0 +1,52 @@ +package cert + +import ( + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +// readOptionalASN1Boolean reads an asn.1 boolean with a specific tag instead of a asn.1 tag wrapping a boolean with a value +// https://github.com/golang/go/issues/64811#issuecomment-1944446920 +func readOptionalASN1Boolean(b *cryptobyte.String, out *bool, tag asn1.Tag, defaultValue bool) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] > 0 + return true + } + + return false +} + +// readOptionalASN1Byte reads an asn.1 uint8 with a specific tag instead of a asn.1 tag wrapping a uint8 with a value +// Similar issue as with readOptionalASN1Boolean +func readOptionalASN1Byte(b *cryptobyte.String, out *byte, tag asn1.Tag, defaultValue byte) bool { + var present bool + var child cryptobyte.String + if !b.ReadOptionalASN1(&child, &present, tag) { + return false + } + + if !present { + *out = defaultValue + return true + } + + // Ensure we have 1 byte + if len(child) == 1 { + *out = child[0] + return true + } + + return false +} diff --git a/cert/ca_pool.go b/cert/ca_pool.go index d52583035..2bf480f22 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -213,7 +213,7 @@ func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { return signer, nil } - return nil, fmt.Errorf("could not find ca for the certificate") + return nil, ErrCaNotFound } // GetFingerprints returns an array of trusted CA fingerprints diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index 053640d98..f03b2ba83 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -1,7 +1,9 @@ package cert import ( + "net/netip" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -10,15 +12,15 @@ func TestNewCAPoolFromBytes(t *testing.T) { noNewLines := ` # Current provisional, Remove once everything moves over to the real root. -----BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== -----END NEBULA CERTIFICATE----- # root-ca01 -----BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== -----END NEBULA CERTIFICATE----- ` @@ -26,18 +28,18 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf # Current provisional, Remove once everything moves over to the real root. -----BEGIN NEBULA CERTIFICATE----- -CkAKDm5lYnVsYSByb290IGNhKJfap9AFMJfg1+YGOiCUQGByMuNRhIlQBOyzXWbL -vcKBwDhov900phEfJ5DN3kABEkDCq5R8qBiu8sl54yVfgRcQXEDt3cHr8UTSLszv -bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB +Cj4KDm5lYnVsYSByb290IGNhKM0cMM24zPCvBzogV24YEw5YiqeI/oYo8XXFsoo+ +PBmiOafNJhLacf9rsspAARJAz9OAnh8TKAUKix1kKVMyQU4iM3LsFfZRf6ODWXIf +2qWMpB6fpd3PSoVYziPoOt2bIHIFLlgRLPJz3I3xBEdBCQ== -----END NEBULA CERTIFICATE----- # root-ca01 -----BEGIN NEBULA CERTIFICATE----- -CkMKEW5lYnVsYSByb290IGNhIDAxKJL2u9EFMJL86+cGOiDPXMH4oU6HZTk/CqTG -BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf -8/phAUt+FLzqTECzQKisYswKvE3pl9mbEYKbOdIHrxdIp95mo4sF +CkEKEW5lYnVsYSByb290IGNhIDAxKM0cMM24zPCvBzogPzbWTxt8ZgXPQEwup7Br +BrtIt1O0q5AuTRT3+t2x1VJAARJAZ+2ib23qBXjdy49oU1YysrwuKkWWKrtJ7Jye +rFBQpDXikOukhQD/mfkloFwJ+Yjsfru7IpTN4ZfjXL+kN/2sCA== -----END NEBULA CERTIFICATE----- ` @@ -45,65 +47,513 @@ BVG+oJpAoqokUBbI4U0N8CSfpUABEkB/Pm5A2xyH/nc8mg/wvGUWG3pZ7nHzaDMf expired := ` # expired certificate -----BEGIN NEBULA CERTIFICATE----- -CjkKB2V4cGlyZWQouPmWjQYwufmWjQY6ILCRaoCkJlqHgv5jfDN4lzLHBvDzaQm4 -vZxfu144hmgjQAESQG4qlnZi8DncvD/LDZnLgJHOaX1DWCHHEh59epVsC+BNgTie -WH1M9n4O7cFtGlM6sJJOS+rCVVEJ3ABS7+MPdQs= +CjMKB2V4cGlyZWQozRwwzRw6ICJSG94CqX8wn5I65Pwn25V6HftVfWeIySVtp2DA +7TY/QAESQMaAk5iJT5EnQwK524ZaaHGEJLUqqbh5yyOHhboIGiVTWkFeH3HccTW8 +Tq5a8AyWDQdfXbtEZ1FwabeHfH5Asw0= -----END NEBULA CERTIFICATE----- ` p256 := ` # p256 certificate -----BEGIN NEBULA CERTIFICATE----- -CmYKEG5lYnVsYSBQMjU2IHRlc3Qo4s+7mgYw4tXrsAc6QQRkaW2jFmllYvN4+/k2 -6tctO9sPT3jOx8ES6M1nIqOhpTmZeabF/4rELDqPV4aH5jfJut798DUXql0FlF8H -76gvQAGgBgESRzBFAiEAib0/te6eMiZOKD8gdDeloMTS0wGuX2t0C7TFdUhAQzgC -IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX +CmQKEG5lYnVsYSBQMjU2IHRlc3QozRwwzbjM8K8HOkEEdrmmg40zQp44AkMq6DZp +k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe ++0ABoAYBEkcwRQIgVoTg38L7uWku9xQgsr06kxZ/viQLOO/w1Qj1vFUEnhcCIQCq +75SjTiV92kv/1GcbT3wWpAZQQDBiUHVMVmh1822szA== -----END NEBULA CERTIFICATE----- ` rootCA := certificateV1{ details: detailsV1{ - Name: "nebula root ca", + name: "nebula root ca", }, } rootCA01 := certificateV1{ details: detailsV1{ - Name: "nebula root ca 01", + name: "nebula root ca 01", }, } rootCAP256 := certificateV1{ details: detailsV1{ - Name: "nebula P256 test", + name: "nebula P256 test", }, } p, err := NewCAPoolFromPEM([]byte(noNewLines)) assert.Nil(t, err) - assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) assert.Nil(t, err) - assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) // expired cert, no valid certs ppp, err := NewCAPoolFromPEM([]byte(expired)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") + assert.Equal(t, ppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") // expired cert, with valid certs pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) - assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) - assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") + assert.Equal(t, pppp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) + assert.Equal(t, pppp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) + assert.Equal(t, pppp.CAs["c39b35a0e8f246203fe4f32b9aa8bfd155f1ae6a6be9d78370641e43397f48f5"].Certificate.Name(), "expired") assert.Equal(t, len(pppp.CAs), 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) assert.Nil(t, err) - assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name) + assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Equal(t, len(ppppp.CAs), 1) } + +func TestCertificateV1_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV1_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_VerifyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + + caPool := NewCAPool() + assert.NoError(t, caPool.AddCA(ca)) + + f, err := c.Fingerprint() + assert.Nil(t, err) + caPool.BlocklistFingerprint(f) + + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.EqualError(t, err, "certificate is in the block list") + + caPool.ResetCertBlocklist() + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) + assert.EqualError(t, err, "root certificate is expired") + + assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + }) + + // Test group assertion + ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool = NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { + NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) + }) + + c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify_IPs(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} + +func TestCertificateV2_Verify_Subnets(t *testing.T) { + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + + caPem, err := ca.MarshalPEM() + assert.Nil(t, err) + + caPool := NewCAPool() + b, err := caPool.AddCAFromPEM(caPem) + assert.NoError(t, err) + assert.Empty(t, b) + + // ip is outside the network + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is outside the network reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip is within the network but mask is outside reversed order of above + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + assert.PanicsWithError(t, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15", func() { + NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + }) + + // ip and mask are within the network + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) + + // Exact matches reversed with just 1 + c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) + assert.Nil(t, err) + _, err = caPool.VerifyCertificate(time.Now(), c) + assert.Nil(t, err) +} diff --git a/cert/cert.go b/cert/cert.go index 02c88777c..d11025843 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -1,15 +1,17 @@ package cert import ( + "fmt" "net/netip" "time" ) -type Version int +type Version uint8 const ( - Version1 Version = 1 - Version2 Version = 2 + VersionPre1 Version = 0 + Version1 Version = 1 + Version2 Version = 2 ) type Certificate interface { @@ -107,23 +109,56 @@ type CachedCertificate struct { signerFingerprint string } -// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. -func UnmarshalCertificate(b []byte) (Certificate, error) { - c, err := unmarshalCertificateV1(b, true) - if err != nil { - return nil, err - } - return c, nil +func (cc *CachedCertificate) String() string { + return cc.Certificate.String() } -// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. +// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake. // Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. -func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { - c, err := unmarshalCertificateV1(b, false) +func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { + if publicKey == nil { + return nil, ErrNoPeerStaticKey + } + + if rawCertBytes == nil { + return nil, ErrNoPayload + } + + c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve) + if err != nil { + return nil, fmt.Errorf("error unmarshaling cert: %w", err) + } + + cc, err := caPool.VerifyCertificate(time.Now(), c) + if err != nil { + return nil, fmt.Errorf("certificate validation failed: %w", err) + } + + return cc, nil +} + +func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) { + var c Certificate + var err error + + switch v { + case VersionPre1, Version1: + c, err = unmarshalCertificateV1(b, publicKey) + case Version2: + c, err = unmarshalCertificateV2(b, publicKey, curve) + default: + //TODO: make a static var + return nil, fmt.Errorf("unknown certificate version %d", v) + } + if err != nil { return nil, err } - c.details.PublicKey = publicKey + + if c.Curve() != curve { + return nil, fmt.Errorf("certificate curve %s does not match expected %s", c.Curve().String(), curve.String()) + } + return c, nil } diff --git a/cert/cert_test.go b/cert/cert_test.go deleted file mode 100644 index 12bbd9700..000000000 --- a/cert/cert_test.go +++ /dev/null @@ -1,695 +0,0 @@ -package cert - -import ( - "crypto/ecdh" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "fmt" - "io" - "net/netip" - "testing" - "time" - - "github.com/slackhq/nebula/test" - "github.com/stretchr/testify/assert" - "golang.org/x/crypto/curve25519" - "golang.org/x/crypto/ed25519" -) - -func TestMarshalingNebulaCertificate(t *testing.T) { - before := time.Now().Add(time.Second * -60).Round(time.Second) - after := time.Now().Add(time.Second * 60).Round(time.Second) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := certificateV1{ - details: detailsV1{ - Name: "testing", - Ips: []netip.Prefix{ - mustParsePrefixUnmapped("10.1.1.1/24"), - mustParsePrefixUnmapped("10.1.1.2/16"), - }, - Subnets: []netip.Prefix{ - mustParsePrefixUnmapped("9.1.1.2/24"), - mustParsePrefixUnmapped("9.1.1.3/16"), - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - - nc2, err := unmarshalCertificateV1(b, true) - assert.Nil(t, err) - - assert.Equal(t, nc.signature, nc2.Signature()) - assert.Equal(t, nc.details.Name, nc2.Name()) - assert.Equal(t, nc.details.NotBefore, nc2.NotBefore()) - assert.Equal(t, nc.details.NotAfter, nc2.NotAfter()) - assert.Equal(t, nc.details.PublicKey, nc2.PublicKey()) - assert.Equal(t, nc.details.IsCA, nc2.IsCA()) - - assert.Equal(t, nc.details.Ips, nc2.Networks()) - assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks()) - - assert.Equal(t, nc.details.Groups, nc2.Groups()) -} - -//func TestNebulaCertificate_Sign(t *testing.T) { -// before := time.Now().Add(time.Second * -60).Round(time.Second) -// after := time.Now().Add(time.Second * 60).Round(time.Second) -// pubKey := []byte("1234567890abcedfghij1234567890ab") -// -// nc := certificateV1{ -// details: detailsV1{ -// Name: "testing", -// Ips: []netip.Prefix{ -// mustParsePrefixUnmapped("10.1.1.1/24"), -// mustParsePrefixUnmapped("10.1.1.2/16"), -// //TODO: netip cant do it -// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// }, -// Subnets: []netip.Prefix{ -// //TODO: netip cant do it -// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// mustParsePrefixUnmapped("9.1.1.2/24"), -// mustParsePrefixUnmapped("9.1.1.3/24"), -// }, -// Groups: []string{"test-group1", "test-group2", "test-group3"}, -// NotBefore: before, -// NotAfter: after, -// PublicKey: pubKey, -// IsCA: false, -// Issuer: "1234567890abcedfghij1234567890ab", -// }, -// } -// -// pub, priv, err := ed25519.GenerateKey(rand.Reader) -// assert.Nil(t, err) -// assert.False(t, nc.CheckSignature(pub)) -// assert.Nil(t, nc.Sign(Curve_CURVE25519, priv)) -// assert.True(t, nc.CheckSignature(pub)) -// -// _, err = nc.Marshal() -// assert.Nil(t, err) -// //t.Log("Cert size:", len(b)) -//} - -//func TestNebulaCertificate_SignP256(t *testing.T) { -// before := time.Now().Add(time.Second * -60).Round(time.Second) -// after := time.Now().Add(time.Second * 60).Round(time.Second) -// pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") -// -// nc := certificateV1{ -// details: detailsV1{ -// Name: "testing", -// Ips: []netip.Prefix{ -// mustParsePrefixUnmapped("10.1.1.1/24"), -// mustParsePrefixUnmapped("10.1.1.2/16"), -// //TODO: netip no can do -// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// }, -// Subnets: []netip.Prefix{ -// //TODO: netip bad -// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// mustParsePrefixUnmapped("9.1.1.2/24"), -// mustParsePrefixUnmapped("9.1.1.3/16"), -// }, -// Groups: []string{"test-group1", "test-group2", "test-group3"}, -// NotBefore: before, -// NotAfter: after, -// PublicKey: pubKey, -// IsCA: false, -// Curve: Curve_P256, -// Issuer: "1234567890abcedfghij1234567890ab", -// }, -// } -// -// priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) -// pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) -// rawPriv := priv.D.FillBytes(make([]byte, 32)) -// -// assert.Nil(t, err) -// assert.False(t, nc.CheckSignature(pub)) -// assert.Nil(t, nc.Sign(Curve_P256, rawPriv)) -// assert.True(t, nc.CheckSignature(pub)) -// -// _, err = nc.Marshal() -// assert.Nil(t, err) -// //t.Log("Cert size:", len(b)) -//} - -func TestNebulaCertificate_Expired(t *testing.T) { - nc := certificateV1{ - details: detailsV1{ - NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), - NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), - }, - } - - assert.True(t, nc.Expired(time.Now().Add(time.Hour))) - assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) - assert.False(t, nc.Expired(time.Now())) -} - -func TestNebulaCertificate_MarshalJSON(t *testing.T) { - time.Local = time.UTC - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := certificateV1{ - details: detailsV1{ - Name: "testing", - Ips: []netip.Prefix{ - mustParsePrefixUnmapped("10.1.1.1/24"), - mustParsePrefixUnmapped("10.1.1.2/16"), - }, - Subnets: []netip.Prefix{ - mustParsePrefixUnmapped("9.1.1.2/24"), - mustParsePrefixUnmapped("9.1.1.3/16"), - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), - NotAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.MarshalJSON() - assert.Nil(t, err) - assert.Equal( - t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", - string(b), - ) -} - -func TestNebulaCertificate_Verify(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) - - f, err := c.Fingerprint() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - assert.EqualError(t, err, "certificate is valid before the signing certificate") - - // Test group assertion - ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) - - f, err := c.Fingerprint() - assert.Nil(t, err) - caPool.BlocklistFingerprint(f) - - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") - - caPool.ResetCertBlocklist() - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") - - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - assert.EqualError(t, err, "certificate is valid before the signing certificate") - - // Test group assertion - ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool = NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) - assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_IPs(t *testing.T) { - caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") - caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - // ip is outside the network - cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") - cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.EqualError(t, err, "certificate contained a network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_Verify_Subnets(t *testing.T) { - caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") - caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - - caPem, err := ca.MarshalPEM() - assert.Nil(t, err) - - caPool := NewCAPool() - b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) - assert.Empty(t, b) - - // ip is outside the network - cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") - cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is outside the network reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.1.0.0/24") - - // ip is within the network but mask is outside - cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip is within the network but mask is outside reversed order of above - cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") - cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.EqualError(t, err, "certificate contained an unsafe network assignment outside the limitations of the signing ca: 10.0.1.0/15") - - // ip and mask are within the network - cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") - cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) - - // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.Nil(t, err) - _, err = caPool.VerifyCertificate(time.Now(), c) - assert.Nil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - err = c.VerifyPrivateKey(Curve_CURVE25519, priv) - assert.Nil(t, err) - - _, priv2 := x25519Keypair() - err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.NotNil(t, err) -} - -func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey) - assert.Nil(t, err) - - _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) - assert.Nil(t, err) - err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.NotNil(t, err) - - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) - err = c.VerifyPrivateKey(Curve_P256, priv) - assert.Nil(t, err) - - _, priv2 := p256Keypair() - err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.NotNil(t, err) -} - -func appendByteSlices(b ...[]byte) []byte { - retSlice := []byte{} - for _, v := range b { - retSlice = append(retSlice, v...) - } - return retSlice -} - -// Ensure that upgrading the protobuf library does not change how certificates -// are marshalled, since this would break signature verification -//TODO: since netip cant represent 255.0.255.0 netmask we can't verify the old certs are ok -//func TestMarshalingNebulaCertificateConsistency(t *testing.T) { -// before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) -// after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) -// pubKey := []byte("1234567890abcedfghij1234567890ab") -// -// nc := certificateV1{ -// details: detailsV1{ -// Name: "testing", -// Ips: []netip.Prefix{ -// mustParsePrefixUnmapped("10.1.1.1/24"), -// mustParsePrefixUnmapped("10.1.1.2/16"), -// //TODO: netip bad -// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// }, -// Subnets: []netip.Prefix{ -// //TODO: netip bad -// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, -// mustParsePrefixUnmapped("9.1.1.2/24"), -// mustParsePrefixUnmapped("9.1.1.3/16"), -// }, -// Groups: []string{"test-group1", "test-group2", "test-group3"}, -// NotBefore: before, -// NotAfter: after, -// PublicKey: pubKey, -// IsCA: false, -// Issuer: "1234567890abcedfghij1234567890ab", -// }, -// signature: []byte("1234567890abcedfghij1234567890ab"), -// } -// -// b, err := nc.Marshal() -// assert.Nil(t, err) -// //t.Log("Cert size:", len(b)) -// assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) -// -// b, err = proto.Marshal(nc.getRawDetails()) -// assert.Nil(t, err) -// //t.Log("Raw cert size:", len(b)) -// assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) -//} - -func TestNebulaCertificate_Copy(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) - assert.Nil(t, err) - cc := c.Copy() - - test.AssertDeepCopyEqual(t, c, cc) -} - -func TestUnmarshalNebulaCertificate(t *testing.T) { - // Test that we don't panic with an invalid certificate (#332) - data := []byte("\x98\x00\x00") - _, err := unmarshalCertificateV1(data, true) - assert.EqualError(t, err, "encoded Details was nil") -} - -func newTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - tbs := &TBSCertificate{ - Version: Version1, - Name: "test ca", - IsCA: true, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - } - - if len(ips) > 0 { - tbs.Networks = ips - } - - if len(subnets) > 0 { - tbs.UnsafeNetworks = subnets - } - - if len(groups) > 0 { - tbs.Groups = groups - } - - nc, err := tbs.Sign(nil, Curve_CURVE25519, priv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, priv, nil -} - -func newTestCaCertP256(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) - rawPriv := priv.D.FillBytes(make([]byte, 32)) - - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - tbs := &TBSCertificate{ - Version: Version1, - Name: "test ca", - IsCA: true, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - Curve: Curve_P256, - } - - if len(ips) > 0 { - tbs.Networks = ips - } - - if len(subnets) > 0 { - tbs.UnsafeNetworks = subnets - } - - if len(groups) > 0 { - tbs.Groups = groups - } - - nc, err := tbs.Sign(nil, Curve_P256, rawPriv) - if err != nil { - return nil, nil, nil, err - } - return nc, pub, rawPriv, nil -} - -func newTestCert(ca Certificate, key []byte, before, after time.Time, ips, subnets []netip.Prefix, groups []string) (Certificate, []byte, []byte, error) { - if before.IsZero() { - before = time.Now().Add(time.Second * -60).Round(time.Second) - } - if after.IsZero() { - after = time.Now().Add(time.Second * 60).Round(time.Second) - } - - if len(groups) == 0 { - groups = []string{"test-group1", "test-group2", "test-group3"} - } - - if len(ips) == 0 { - ips = []netip.Prefix{ - mustParsePrefixUnmapped("10.1.1.1/24"), - mustParsePrefixUnmapped("10.1.1.2/16"), - } - } - - if len(subnets) == 0 { - subnets = []netip.Prefix{ - mustParsePrefixUnmapped("9.1.1.2/24"), - mustParsePrefixUnmapped("9.1.1.3/16"), - } - } - - var pub, rawPriv []byte - - switch ca.Curve() { - case Curve_CURVE25519: - pub, rawPriv = x25519Keypair() - case Curve_P256: - pub, rawPriv = p256Keypair() - default: - return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Curve()) - } - - tbs := &TBSCertificate{ - Version: Version1, - Name: "testing", - Networks: ips, - UnsafeNetworks: subnets, - Groups: groups, - IsCA: false, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - Curve: ca.Curve(), - } - - nc, err := tbs.Sign(ca, ca.Curve(), key) - if err != nil { - return nil, nil, nil, err - } - - return nc, pub, rawPriv, nil -} - -func x25519Keypair() ([]byte, []byte) { - privkey := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, privkey); err != nil { - panic(err) - } - - pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) - if err != nil { - panic(err) - } - - return pubkey, privkey -} - -func p256Keypair() ([]byte, []byte) { - privkey, err := ecdh.P256().GenerateKey(rand.Reader) - if err != nil { - panic(err) - } - pubkey := privkey.PublicKey() - return pubkey.Bytes(), privkey.Bytes() -} - -func mustParsePrefixUnmapped(s string) netip.Prefix { - prefix := netip.MustParsePrefix(s) - return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) -} diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 165e40951..b807f8d21 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -6,19 +6,16 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" - "crypto/rand" "crypto/sha256" "encoding/binary" "encoding/hex" "encoding/json" "encoding/pem" "fmt" - "math/big" "net" "net/netip" "time" - "github.com/slackhq/nebula/pkclient" "golang.org/x/crypto/curve25519" "google.golang.org/protobuf/proto" ) @@ -31,71 +28,71 @@ type certificateV1 struct { } type detailsV1 struct { - Name string - Ips []netip.Prefix - Subnets []netip.Prefix - Groups []string - NotBefore time.Time - NotAfter time.Time - PublicKey []byte - IsCA bool - Issuer string - - Curve Curve + name string + networks []netip.Prefix + unsafeNetworks []netip.Prefix + groups []string + notBefore time.Time + notAfter time.Time + publicKey []byte + isCA bool + issuer string + + curve Curve } type m map[string]interface{} -func (nc *certificateV1) Version() Version { +func (c *certificateV1) Version() Version { return Version1 } -func (nc *certificateV1) Curve() Curve { - return nc.details.Curve +func (c *certificateV1) Curve() Curve { + return c.details.curve } -func (nc *certificateV1) Groups() []string { - return nc.details.Groups +func (c *certificateV1) Groups() []string { + return c.details.groups } -func (nc *certificateV1) IsCA() bool { - return nc.details.IsCA +func (c *certificateV1) IsCA() bool { + return c.details.isCA } -func (nc *certificateV1) Issuer() string { - return nc.details.Issuer +func (c *certificateV1) Issuer() string { + return c.details.issuer } -func (nc *certificateV1) Name() string { - return nc.details.Name +func (c *certificateV1) Name() string { + return c.details.name } -func (nc *certificateV1) Networks() []netip.Prefix { - return nc.details.Ips +func (c *certificateV1) Networks() []netip.Prefix { + return c.details.networks } -func (nc *certificateV1) NotAfter() time.Time { - return nc.details.NotAfter +func (c *certificateV1) NotAfter() time.Time { + return c.details.notAfter } -func (nc *certificateV1) NotBefore() time.Time { - return nc.details.NotBefore +func (c *certificateV1) NotBefore() time.Time { + return c.details.notBefore } -func (nc *certificateV1) PublicKey() []byte { - return nc.details.PublicKey +func (c *certificateV1) PublicKey() []byte { + return c.details.publicKey } -func (nc *certificateV1) Signature() []byte { - return nc.signature +func (c *certificateV1) Signature() []byte { + return c.signature } -func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { - return nc.details.Subnets +func (c *certificateV1) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks } -func (nc *certificateV1) Fingerprint() (string, error) { - b, err := nc.Marshal() +func (c *certificateV1) Fingerprint() (string, error) { + b, err := c.Marshal() if err != nil { return "", err } @@ -104,33 +101,33 @@ func (nc *certificateV1) Fingerprint() (string, error) { return hex.EncodeToString(sum[:]), nil } -func (nc *certificateV1) CheckSignature(key []byte) bool { - b, err := proto.Marshal(nc.getRawDetails()) +func (c *certificateV1) CheckSignature(key []byte) bool { + b, err := proto.Marshal(c.getRawDetails()) if err != nil { return false } - switch nc.details.Curve { + switch c.details.curve { case Curve_CURVE25519: - return ed25519.Verify(key, b, nc.signature) + return ed25519.Verify(key, b, c.signature) case Curve_P256: x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} hashed := sha256.Sum256(b) - return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) default: return false } } -func (nc *certificateV1) Expired(t time.Time) bool { - return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) +func (c *certificateV1) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) } -func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { - if curve != nc.details.Curve { +func (c *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.details.curve { return fmt.Errorf("curve in cert and private key supplied don't match") } - if nc.details.IsCA { + if c.details.isCA { switch curve { case Curve_CURVE25519: // the call to PublicKey below will panic slice bounds out of range otherwise @@ -138,7 +135,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } - if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { + if !ed25519.PublicKey(c.details.publicKey).Equal(ed25519.PrivateKey(key).Public()) { return fmt.Errorf("public key in cert and private key supplied don't match") } case Curve_P256: @@ -147,7 +144,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("cannot parse private key as P256: %w", err) } pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, nc.details.PublicKey) { + if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } default: @@ -173,7 +170,7 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { default: return fmt.Errorf("invalid curve: %s", curve) } - if !bytes.Equal(pub, nc.details.PublicKey) { + if !bytes.Equal(pub, c.details.publicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } @@ -181,173 +178,167 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { } // getRawDetails marshals the raw details into protobuf ready struct -func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { +func (c *certificateV1) getRawDetails() *RawNebulaCertificateDetails { rd := &RawNebulaCertificateDetails{ - Name: nc.details.Name, - Groups: nc.details.Groups, - NotBefore: nc.details.NotBefore.Unix(), - NotAfter: nc.details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.details.PublicKey)), - IsCA: nc.details.IsCA, - Curve: nc.details.Curve, + Name: c.details.name, + Groups: c.details.groups, + NotBefore: c.details.notBefore.Unix(), + NotAfter: c.details.notAfter.Unix(), + PublicKey: make([]byte, len(c.details.publicKey)), + IsCA: c.details.isCA, + Curve: c.details.curve, } - for _, ipNet := range nc.details.Ips { + for _, ipNet := range c.details.networks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) } - for _, ipNet := range nc.details.Subnets { + for _, ipNet := range c.details.unsafeNetworks { mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) } - copy(rd.PublicKey, nc.details.PublicKey[:]) + copy(rd.PublicKey, c.details.publicKey[:]) // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) + rd.Issuer, _ = hex.DecodeString(c.details.issuer) return rd } -func (nc *certificateV1) String() string { - if nc == nil { - return "Certificate {}\n" - } - - s := "NebulaCertificate {\n" - s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name) - - if len(nc.details.Ips) > 0 { - s += "\t\tIps: [\n" - for _, ip := range nc.details.Ips { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tIps: []\n" - } - - if len(nc.details.Subnets) > 0 { - s += "\t\tSubnets: [\n" - for _, ip := range nc.details.Subnets { - s += fmt.Sprintf("\t\t\t%v\n", ip.String()) - } - s += "\t\t]\n" - } else { - s += "\t\tSubnets: []\n" - } - - if len(nc.details.Groups) > 0 { - s += "\t\tGroups: [\n" - for _, g := range nc.details.Groups { - s += fmt.Sprintf("\t\t\t\"%v\"\n", g) - } - s += "\t\t]\n" - } else { - s += "\t\tGroups: []\n" - } - - s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve) - s += "\t}\n" - fp, err := nc.Fingerprint() - if err == nil { - s += fmt.Sprintf("\tFingerprint: %s\n", fp) +func (c *certificateV1) String() string { + b, err := json.MarshalIndent(c.marshalJSON(), "", "\t") + if err != nil { + return fmt.Sprintf("", err) } - s += fmt.Sprintf("\tSignature: %x\n", nc.Signature()) - s += "}" - - return s + return string(b) } -func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { - pubKey := nc.details.PublicKey - nc.details.PublicKey = nil - rawCertNoKey, err := nc.Marshal() +func (c *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := c.details.publicKey + c.details.publicKey = nil + rawCertNoKey, err := c.Marshal() if err != nil { return nil, err } - nc.details.PublicKey = pubKey + c.details.publicKey = pubKey return rawCertNoKey, nil } -func (nc *certificateV1) Marshal() ([]byte, error) { +func (c *certificateV1) Marshal() ([]byte, error) { rc := RawNebulaCertificate{ - Details: nc.getRawDetails(), - Signature: nc.signature, + Details: c.getRawDetails(), + Signature: c.signature, } return proto.Marshal(&rc) } -func (nc *certificateV1) MarshalPEM() ([]byte, error) { - b, err := nc.Marshal() +func (c *certificateV1) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() if err != nil { return nil, err } return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil } -func (nc *certificateV1) MarshalJSON() ([]byte, error) { - fp, _ := nc.Fingerprint() - jc := m{ +func (c *certificateV1) MarshalJSON() ([]byte, error) { + return json.Marshal(c.marshalJSON()) +} + +func (c *certificateV1) marshalJSON() m { + fp, _ := c.Fingerprint() + return m{ + "version": Version1, "details": m{ - "name": nc.details.Name, - "ips": nc.details.Ips, - "subnets": nc.details.Subnets, - "groups": nc.details.Groups, - "notBefore": nc.details.NotBefore, - "notAfter": nc.details.NotAfter, - "publicKey": fmt.Sprintf("%x", nc.details.PublicKey), - "isCa": nc.details.IsCA, - "issuer": nc.details.Issuer, - "curve": nc.details.Curve.String(), + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "publicKey": fmt.Sprintf("%x", c.details.publicKey), + "isCa": c.details.isCA, + "issuer": c.details.issuer, + "curve": c.details.curve.String(), }, "fingerprint": fp, - "signature": fmt.Sprintf("%x", nc.Signature()), + "signature": fmt.Sprintf("%x", c.Signature()), } - return json.Marshal(jc) } -func (nc *certificateV1) Copy() Certificate { - c := &certificateV1{ +func (c *certificateV1) Copy() Certificate { + nc := &certificateV1{ details: detailsV1{ - Name: nc.details.Name, - Groups: make([]string, len(nc.details.Groups)), - Ips: make([]netip.Prefix, len(nc.details.Ips)), - Subnets: make([]netip.Prefix, len(nc.details.Subnets)), - NotBefore: nc.details.NotBefore, - NotAfter: nc.details.NotAfter, - PublicKey: make([]byte, len(nc.details.PublicKey)), - IsCA: nc.details.IsCA, - Issuer: nc.details.Issuer, + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + publicKey: make([]byte, len(c.details.publicKey)), + isCA: c.details.isCA, + issuer: c.details.issuer, + curve: c.details.curve, }, - signature: make([]byte, len(nc.signature)), + signature: make([]byte, len(c.signature)), + } + + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) } - copy(c.signature, nc.signature) - copy(c.details.Groups, nc.details.Groups) - copy(c.details.PublicKey, nc.details.PublicKey) + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) + } - for i, p := range nc.details.Ips { - c.details.Ips[i] = p + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) } - for i, p := range nc.details.Subnets { - c.details.Subnets[i] = p + copy(nc.signature, c.signature) + copy(nc.details.publicKey, c.details.publicKey) + + return nc +} + +func (c *certificateV1) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV1{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + publicKey: t.PublicKey, + isCA: t.IsCA, + curve: t.Curve, + issuer: t.issuer, + } + + return nil +} + +func (c *certificateV1) marshalForSigning() ([]byte, error) { + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return nil, err } + return b, nil +} - return c +func (c *certificateV1) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil } // unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert -func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { +// if the publicKey is provided here then it is not required to be present in `b` +func unmarshalCertificateV1(b []byte, publicKey []byte) (*certificateV1, error) { if len(b) == 0 { return nil, fmt.Errorf("nil byte array") } @@ -371,27 +362,28 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err nc := certificateV1{ details: detailsV1{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), - Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - Curve: rc.Details.Curve, + name: rc.Details.Name, + groups: make([]string, len(rc.Details.Groups)), + networks: make([]netip.Prefix, len(rc.Details.Ips)/2), + unsafeNetworks: make([]netip.Prefix, len(rc.Details.Subnets)/2), + notBefore: time.Unix(rc.Details.NotBefore, 0), + notAfter: time.Unix(rc.Details.NotAfter, 0), + publicKey: make([]byte, len(rc.Details.PublicKey)), + isCA: rc.Details.IsCA, + curve: rc.Details.Curve, }, signature: make([]byte, len(rc.Signature)), } copy(nc.signature, rc.Signature) - copy(nc.details.Groups, rc.Details.Groups) - nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) + copy(nc.details.groups, rc.Details.Groups) + nc.details.issuer = hex.EncodeToString(rc.Details.Issuer) - if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { - return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) + if len(publicKey) > 0 { + nc.details.publicKey = publicKey } - copy(nc.details.PublicKey, rc.Details.PublicKey) + + copy(nc.details.publicKey, rc.Details.PublicKey) var ip netip.Addr for i, rawIp := range rc.Details.Ips { @@ -399,7 +391,7 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) + nc.details.networks[i/2] = netip.PrefixFrom(ip, ones) } } @@ -408,69 +400,13 @@ func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, err ip = int2addr(rawIp) } else { ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) + nc.details.unsafeNetworks[i/2] = netip.PrefixFrom(ip, ones) } } return &nc, nil } -func signV1(t *TBSCertificate, curve Curve, key []byte, client *pkclient.PKClient) (*certificateV1, error) { - c := &certificateV1{ - details: detailsV1{ - Name: t.Name, - Ips: t.Networks, - Subnets: t.UnsafeNetworks, - Groups: t.Groups, - NotBefore: t.NotBefore, - NotAfter: t.NotAfter, - PublicKey: t.PublicKey, - IsCA: t.IsCA, - Curve: t.Curve, - Issuer: t.issuer, - }, - } - b, err := proto.Marshal(c.getRawDetails()) - if err != nil { - return nil, err - } - - var sig []byte - - switch curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, b) - case Curve_P256: - if client != nil { - sig, err = client.SignASN1(b) - } else { - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(b) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - if err != nil { - return nil, err - } - } - default: - return nil, fmt.Errorf("invalid curve: %s", c.details.Curve) - } - - c.signature = sig - return c, nil -} - func ip2int(ip []byte) uint32 { if len(ip) == 16 { return binary.BigEndian.Uint32(ip[12:16]) diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go new file mode 100644 index 000000000..8c3fe930b --- /dev/null +++ b/cert/cert_v1_test.go @@ -0,0 +1,218 @@ +package cert + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" +) + +func TestCertificateV1_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + assert.Nil(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + + assert.Equal(t, nc.Version(), Version1) + assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV1_Expired(t *testing.T) { + nc := certificateV1{ + details: detailsV1{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV1_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", + string(b), + ) +} + +func TestCertificateV1_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + assert.Nil(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + assert.NotNil(t, err) +} + +func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + assert.Nil(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.NotNil(t, err) +} + +// Ensure that upgrading the protobuf library does not change how certificates +// are marshalled, since this would break signature verification +func TestMarshalingCertificateV1Consistency(t *testing.T) { + before := time.Date(1970, time.January, 1, 1, 1, 1, 1, time.UTC) + after := time.Date(9999, time.January, 1, 1, 1, 1, 1, time.UTC) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV1{ + details: detailsV1{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + publicKey: pubKey, + isCA: false, + issuer: "1234567890abcedfghij1234567890ab", + }, + signature: []byte("1234567890abcedfghij1234567890ab"), + } + + b, err := nc.Marshal() + require.Nil(t, err) + assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) + + b, err = proto.Marshal(nc.getRawDetails()) + assert.Nil(t, err) + assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) +} + +func TestCertificateV1_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV1(t *testing.T) { + // Test that we don't panic with an invalid certificate (#332) + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV1(data, nil) + assert.EqualError(t, err, "encoded Details was nil") +} + +func appendByteSlices(b ...[]byte) []byte { + retSlice := []byte{} + for _, v := range b { + retSlice = append(retSlice, v...) + } + return retSlice +} + +func mustParsePrefixUnmapped(s string) netip.Prefix { + prefix := netip.MustParsePrefix(s) + return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) +} diff --git a/cert/cert_v2.asn1 b/cert/cert_v2.asn1 new file mode 100644 index 000000000..f863133f7 --- /dev/null +++ b/cert/cert_v2.asn1 @@ -0,0 +1,37 @@ +Nebula DEFINITIONS AUTOMATIC TAGS ::= BEGIN + +Name ::= UTF8String (SIZE (1..253)) +Time ::= INTEGER (0..18446744073709551615) -- Seconds since unix epoch, uint64 maximum +Network ::= OCTET STRING (SIZE (5,17)) -- IP addresses are 4 or 16 bytes + 1 byte for the prefix length +Curve ::= ENUMERATED { + curve25519 (0), + p256 (1) +} + +-- The maximum size of a certificate must not exceed 65536 bytes +Certificate ::= SEQUENCE { + details OCTET STRING, + curve Curve DEFAULT curve25519, + publicKey OCTET STRING, + -- signature(details + curve + publicKey) using the appropriate method for curve + signature OCTET STRING +} + +Details ::= SEQUENCE { + name Name, + + -- At least 1 ipv4 or ipv6 address must be present if isCA is false + networks SEQUENCE OF Network OPTIONAL, + unsafeNetworks SEQUENCE OF Network OPTIONAL, + groups SEQUENCE OF Name OPTIONAL, + isCA BOOLEAN DEFAULT false, + notBefore Time, + notAfter Time, + + -- issuer is only required if isCA is false, if isCA is true then it must not be present + issuer OCTET STRING OPTIONAL, + ... + -- New fields can be added below here +} + +END \ No newline at end of file diff --git a/cert/cert_v2.go b/cert/cert_v2.go new file mode 100644 index 000000000..dce929684 --- /dev/null +++ b/cert/cert_v2.go @@ -0,0 +1,655 @@ +package cert + +import ( + "bytes" + "crypto/ecdh" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "net/netip" + "slices" + "time" + + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" + "golang.org/x/crypto/curve25519" +) + +const ( + classConstructed = 0x20 + classContextSpecific = 0x80 + + TagCertDetails = 0 | classConstructed | classContextSpecific + TagCertCurve = 1 | classContextSpecific + TagCertPublicKey = 2 | classContextSpecific + TagCertSignature = 3 | classContextSpecific + + TagDetailsName = 0 | classContextSpecific + TagDetailsNetworks = 1 | classConstructed | classContextSpecific + TagDetailsUnsafeNetworks = 2 | classConstructed | classContextSpecific + TagDetailsGroups = 3 | classConstructed | classContextSpecific + TagDetailsIsCA = 4 | classContextSpecific + TagDetailsNotBefore = 5 | classContextSpecific + TagDetailsNotAfter = 6 | classContextSpecific + TagDetailsIssuer = 7 | classContextSpecific +) + +const ( + // MaxCertificateSize is the maximum length a valid certificate can be + MaxCertificateSize = 65536 + + // MaxNameLength is limited to a maximum realistic DNS domain name to help facilitate DNS systems + MaxNameLength = 253 + + // MaxNetworkLength is the maximum length a network value can be. + // 16 bytes for an ipv6 address + 1 byte for the prefix length + MaxNetworkLength = 17 +) + +type certificateV2 struct { + details detailsV2 + + // RawDetails contains the entire asn.1 DER encoded Details struct + // This is to benefit forwards compatibility in signature checking. + // signature(RawDetails + Curve + PublicKey) == Signature + rawDetails []byte + curve Curve + publicKey []byte + signature []byte +} + +type detailsV2 struct { + name string + networks []netip.Prefix + unsafeNetworks []netip.Prefix + groups []string + isCA bool + notBefore time.Time + notAfter time.Time + issuer string +} + +func (c *certificateV2) Version() Version { + return Version2 +} + +func (c *certificateV2) Curve() Curve { + return c.curve +} + +func (c *certificateV2) Groups() []string { + return c.details.groups +} + +func (c *certificateV2) IsCA() bool { + return c.details.isCA +} + +func (c *certificateV2) Issuer() string { + return c.details.issuer +} + +func (c *certificateV2) Name() string { + return c.details.name +} + +func (c *certificateV2) Networks() []netip.Prefix { + return c.details.networks +} + +func (c *certificateV2) NotAfter() time.Time { + return c.details.notAfter +} + +func (c *certificateV2) NotBefore() time.Time { + return c.details.notBefore +} + +func (c *certificateV2) PublicKey() []byte { + return c.publicKey +} + +func (c *certificateV2) Signature() []byte { + return c.signature +} + +func (c *certificateV2) UnsafeNetworks() []netip.Prefix { + return c.details.unsafeNetworks +} + +func (c *certificateV2) Fingerprint() (string, error) { + if len(c.rawDetails) == 0 { + return "", ErrMissingDetails + } + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)+len(c.signature)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + copy(b[len(c.rawDetails)+1+len(c.publicKey):], c.signature) + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil +} + +func (c *certificateV2) CheckSignature(key []byte) bool { + if len(c.rawDetails) == 0 { + return false + } + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + + switch c.curve { + case Curve_CURVE25519: + return ed25519.Verify(key, b, c.signature) + case Curve_P256: + x, y := elliptic.Unmarshal(elliptic.P256(), key) + pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} + hashed := sha256.Sum256(b) + return ecdsa.VerifyASN1(pubKey, hashed[:], c.signature) + default: + return false + } +} + +func (c *certificateV2) Expired(t time.Time) bool { + return c.details.notBefore.After(t) || c.details.notAfter.Before(t) +} + +func (c *certificateV2) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != c.curve { + return ErrPublicPrivateCurveMismatch + } + if c.details.isCA { + switch curve { + case Curve_CURVE25519: + // the call to PublicKey below will panic slice bounds out of range otherwise + if len(key) != ed25519.PrivateKeySize { + return ErrInvalidPrivateKey + } + + if !ed25519.PublicKey(c.publicKey).Equal(ed25519.PrivateKey(key).Public()) { + return ErrPublicPrivateKeyMismatch + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub := privkey.PublicKey().Bytes() + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + default: + return fmt.Errorf("invalid curve: %s", curve) + } + return nil + } + + var pub []byte + switch curve { + case Curve_CURVE25519: + var err error + pub, err = curve25519.X25519(key, curve25519.Basepoint) + if err != nil { + return ErrInvalidPrivateKey + } + case Curve_P256: + privkey, err := ecdh.P256().NewPrivateKey(key) + if err != nil { + return ErrInvalidPrivateKey + } + pub = privkey.PublicKey().Bytes() + default: + return fmt.Errorf("invalid curve: %s", curve) + } + if !bytes.Equal(pub, c.publicKey) { + return ErrPublicPrivateKeyMismatch + } + + return nil +} + +func (c *certificateV2) String() string { + mb, err := c.marshalJSON() + if err != nil { + return fmt.Sprintf("", err) + } + + b, err := json.MarshalIndent(mb, "", "\t") + if err != nil { + return fmt.Sprintf("", err) + } + return string(b) +} + +func (c *certificateV2) MarshalForHandshakes() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Skipping the curve and public key since those come across in a different part of the handshake + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) Marshal() ([]byte, error) { + if c.rawDetails == nil { + return nil, ErrEmptyRawDetails + } + var b cryptobyte.Builder + // Outermost certificate + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + + // Add the cert details which is already marshalled + b.AddBytes(c.rawDetails) + + // Add the curve only if its not the default value + if c.curve != Curve_CURVE25519 { + b.AddASN1(TagCertCurve, func(b *cryptobyte.Builder) { + b.AddBytes([]byte{byte(c.curve)}) + }) + } + + // Add the public key if it is not empty + if c.publicKey != nil { + b.AddASN1(TagCertPublicKey, func(b *cryptobyte.Builder) { + b.AddBytes(c.publicKey) + }) + } + + // Add the signature + b.AddASN1(TagCertSignature, func(b *cryptobyte.Builder) { + b.AddBytes(c.signature) + }) + }) + + return b.Bytes() +} + +func (c *certificateV2) MarshalPEM() ([]byte, error) { + b, err := c.Marshal() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: CertificateV2Banner, Bytes: b}), nil +} + +func (c *certificateV2) MarshalJSON() ([]byte, error) { + b, err := c.marshalJSON() + if err != nil { + return nil, err + } + return json.Marshal(b) +} + +func (c *certificateV2) marshalJSON() (m, error) { + fp, err := c.Fingerprint() + if err != nil { + return nil, err + } + + return m{ + "details": m{ + "name": c.details.name, + "networks": c.details.networks, + "unsafeNetworks": c.details.unsafeNetworks, + "groups": c.details.groups, + "notBefore": c.details.notBefore, + "notAfter": c.details.notAfter, + "isCa": c.details.isCA, + "issuer": c.details.issuer, + }, + "version": Version2, + "publicKey": fmt.Sprintf("%x", c.publicKey), + "curve": c.curve.String(), + "fingerprint": fp, + "signature": fmt.Sprintf("%x", c.Signature()), + }, nil +} + +func (c *certificateV2) Copy() Certificate { + nc := &certificateV2{ + details: detailsV2{ + name: c.details.name, + notBefore: c.details.notBefore, + notAfter: c.details.notAfter, + isCA: c.details.isCA, + issuer: c.details.issuer, + }, + curve: c.curve, + publicKey: make([]byte, len(c.publicKey)), + signature: make([]byte, len(c.signature)), + rawDetails: make([]byte, len(c.rawDetails)), + } + + if c.details.groups != nil { + nc.details.groups = make([]string, len(c.details.groups)) + copy(nc.details.groups, c.details.groups) + } + + if c.details.networks != nil { + nc.details.networks = make([]netip.Prefix, len(c.details.networks)) + copy(nc.details.networks, c.details.networks) + } + + if c.details.unsafeNetworks != nil { + nc.details.unsafeNetworks = make([]netip.Prefix, len(c.details.unsafeNetworks)) + copy(nc.details.unsafeNetworks, c.details.unsafeNetworks) + } + + copy(nc.rawDetails, c.rawDetails) + copy(nc.signature, c.signature) + copy(nc.publicKey, c.publicKey) + + return nc +} + +func (c *certificateV2) fromTBSCertificate(t *TBSCertificate) error { + c.details = detailsV2{ + name: t.Name, + networks: t.Networks, + unsafeNetworks: t.UnsafeNetworks, + groups: t.Groups, + isCA: t.IsCA, + notBefore: t.NotBefore, + notAfter: t.NotAfter, + issuer: t.issuer, + } + c.curve = t.Curve + c.publicKey = t.PublicKey + return nil +} + +func (c *certificateV2) marshalForSigning() ([]byte, error) { + d, err := c.details.Marshal() + if err != nil { + return nil, fmt.Errorf("marshalling certificate details failed: %w", err) + } + c.rawDetails = d + + b := make([]byte, len(c.rawDetails)+1+len(c.publicKey)) + copy(b, c.rawDetails) + b[len(c.rawDetails)] = byte(c.curve) + copy(b[len(c.rawDetails)+1:], c.publicKey) + return b, nil +} + +func (c *certificateV2) setSignature(b []byte) error { + if len(b) == 0 { + return ErrEmptySignature + } + c.signature = b + return nil +} + +func (d *detailsV2) Marshal() ([]byte, error) { + var b cryptobyte.Builder + var err error + + // Details are a structure + b.AddASN1(TagCertDetails, func(b *cryptobyte.Builder) { + + // Add the name + b.AddASN1(TagDetailsName, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(d.name)) + }) + + // Add the networks if any exist + if len(d.networks) > 0 { + b.AddASN1(TagDetailsNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.networks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add the unsafe networks if any exist + if len(d.unsafeNetworks) > 0 { + b.AddASN1(TagDetailsUnsafeNetworks, func(b *cryptobyte.Builder) { + for _, n := range d.unsafeNetworks { + sb, innerErr := n.MarshalBinary() + if innerErr != nil { + // MarshalBinary never returns an error + err = fmt.Errorf("unable to marshal unsafe network: %w", innerErr) + return + } + b.AddASN1OctetString(sb) + } + }) + } + + // Add groups if any exist + if len(d.groups) > 0 { + b.AddASN1(TagDetailsGroups, func(b *cryptobyte.Builder) { + for _, group := range d.groups { + b.AddASN1(asn1.UTF8String, func(b *cryptobyte.Builder) { + b.AddBytes([]byte(group)) + }) + } + }) + } + + // Add IsCA only if true + if d.isCA { + b.AddASN1(TagDetailsIsCA, func(b *cryptobyte.Builder) { + b.AddUint8(0xff) + }) + } + + // Add not before + b.AddASN1Int64WithTag(d.notBefore.Unix(), TagDetailsNotBefore) + + // Add not after + b.AddASN1Int64WithTag(d.notAfter.Unix(), TagDetailsNotAfter) + + // Add the issuer if present + if d.issuer != "" { + issuerBytes, innerErr := hex.DecodeString(d.issuer) + if innerErr != nil { + err = fmt.Errorf("failed to decode issuer: %w", innerErr) + return + } + b.AddASN1(TagDetailsIssuer, func(b *cryptobyte.Builder) { + b.AddBytes(issuerBytes) + }) + } + }) + + if err != nil { + return nil, err + } + + return b.Bytes() +} + +func unmarshalCertificateV2(b []byte, publicKey []byte, curve Curve) (*certificateV2, error) { + l := len(b) + if l == 0 || l > MaxCertificateSize { + return nil, ErrBadFormat + } + + input := cryptobyte.String(b) + // Open the envelope + if !input.ReadASN1(&input, asn1.SEQUENCE) || input.Empty() { + return nil, ErrBadFormat + } + + // Grab the cert details, we need to preserve the tag and length + var rawDetails cryptobyte.String + if !input.ReadASN1Element(&rawDetails, TagCertDetails) || rawDetails.Empty() { + return nil, ErrBadFormat + } + + //Maybe grab the curve + var rawCurve byte + if !readOptionalASN1Byte(&input, &rawCurve, TagCertCurve, byte(curve)) { + return nil, ErrBadFormat + } + curve = Curve(rawCurve) + + // Maybe grab the public key + var rawPublicKey cryptobyte.String + if len(publicKey) > 0 { + rawPublicKey = publicKey + } else if !input.ReadOptionalASN1(&rawPublicKey, nil, TagCertPublicKey) { + return nil, ErrBadFormat + } + + if len(rawPublicKey) == 0 { + return nil, ErrBadFormat + } + + // Grab the signature + var rawSignature cryptobyte.String + if !input.ReadASN1(&rawSignature, TagCertSignature) || rawSignature.Empty() { + return nil, ErrBadFormat + } + + // Finally unmarshal the details + details, err := unmarshalDetails(rawDetails) + if err != nil { + return nil, err + } + + return &certificateV2{ + details: details, + rawDetails: rawDetails, + curve: curve, + publicKey: rawPublicKey, + signature: rawSignature, + }, nil +} + +func unmarshalDetails(b cryptobyte.String) (detailsV2, error) { + // Open the envelope + if !b.ReadASN1(&b, TagCertDetails) || b.Empty() { + return detailsV2{}, ErrBadFormat + } + + // Read the name + var name cryptobyte.String + if !b.ReadASN1(&name, TagDetailsName) || name.Empty() || len(name) > MaxNameLength { + return detailsV2{}, ErrBadFormat + } + + // Read the network addresses + var subString cryptobyte.String + var found bool + + if !b.ReadOptionalASN1(&subString, &found, TagDetailsNetworks) { + return detailsV2{}, ErrBadFormat + } + + var networks []netip.Prefix + var val cryptobyte.String + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + networks = append(networks, n) + } + } + + // Read out any unsafe networks + if !b.ReadOptionalASN1(&subString, &found, TagDetailsUnsafeNetworks) { + return detailsV2{}, ErrBadFormat + } + + var unsafeNetworks []netip.Prefix + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.OCTET_STRING) || val.Empty() || len(val) > MaxNetworkLength { + return detailsV2{}, ErrBadFormat + } + + var n netip.Prefix + if err := n.UnmarshalBinary(val); err != nil { + return detailsV2{}, ErrBadFormat + } + unsafeNetworks = append(unsafeNetworks, n) + } + } + + // Read out any groups + if !b.ReadOptionalASN1(&subString, &found, TagDetailsGroups) { + return detailsV2{}, ErrBadFormat + } + + var groups []string + if found { + for !subString.Empty() { + if !subString.ReadASN1(&val, asn1.UTF8String) || val.Empty() { + return detailsV2{}, ErrBadFormat + } + groups = append(groups, string(val)) + } + } + + // Read out IsCA + var isCa bool + if !readOptionalASN1Boolean(&b, &isCa, TagDetailsIsCA, false) { + return detailsV2{}, ErrBadFormat + } + + // Read not before and not after + var notBefore int64 + if !b.ReadASN1Int64WithTag(¬Before, TagDetailsNotBefore) { + return detailsV2{}, ErrBadFormat + } + + var notAfter int64 + if !b.ReadASN1Int64WithTag(¬After, TagDetailsNotAfter) { + return detailsV2{}, ErrBadFormat + } + + // Read issuer + var issuer cryptobyte.String + if !b.ReadOptionalASN1(&issuer, nil, TagDetailsIssuer) { + return detailsV2{}, ErrBadFormat + } + + slices.SortFunc(networks, comparePrefix) + slices.SortFunc(unsafeNetworks, comparePrefix) + + return detailsV2{ + name: string(name), + networks: networks, + unsafeNetworks: unsafeNetworks, + groups: groups, + isCA: isCa, + notBefore: time.Unix(notBefore, 0), + notAfter: time.Unix(notAfter, 0), + issuer: hex.EncodeToString(issuer), + }, nil +} diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go new file mode 100644 index 000000000..3afbcab14 --- /dev/null +++ b/cert/cert_v2_test.go @@ -0,0 +1,267 @@ +package cert + +import ( + "crypto/ed25519" + "crypto/rand" + "encoding/hex" + "net/netip" + "slices" + "testing" + "time" + + "github.com/slackhq/nebula/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCertificateV2_Marshal(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + db, err := nc.details.Marshal() + require.NoError(t, err) + nc.rawDetails = db + + b, err := nc.Marshal() + require.Nil(t, err) + //t.Log("Cert size:", len(b)) + + nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) + assert.Nil(t, err) + + assert.Equal(t, nc.Version(), Version2) + assert.Equal(t, nc.Curve(), Curve_CURVE25519) + assert.Equal(t, nc.Signature(), nc2.Signature()) + assert.Equal(t, nc.Name(), nc2.Name()) + assert.Equal(t, nc.NotBefore(), nc2.NotBefore()) + assert.Equal(t, nc.NotAfter(), nc2.NotAfter()) + assert.Equal(t, nc.PublicKey(), nc2.PublicKey()) + assert.Equal(t, nc.IsCA(), nc2.IsCA()) + assert.Equal(t, nc.Issuer(), nc2.Issuer()) + + // unmarshalling will sort networks and unsafeNetworks, we need to do the same + // but first make sure it fails + assert.NotEqual(t, nc.Networks(), nc2.Networks()) + assert.NotEqual(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + slices.SortFunc(nc.details.networks, comparePrefix) + slices.SortFunc(nc.details.unsafeNetworks, comparePrefix) + + assert.Equal(t, nc.Networks(), nc2.Networks()) + assert.Equal(t, nc.UnsafeNetworks(), nc2.UnsafeNetworks()) + + assert.Equal(t, nc.Groups(), nc2.Groups()) +} + +func TestCertificateV2_Expired(t *testing.T) { + nc := certificateV2{ + details: detailsV2{ + notBefore: time.Now().Add(time.Second * -60).Round(time.Second), + notAfter: time.Now().Add(time.Second * 60).Round(time.Second), + }, + } + + assert.True(t, nc.Expired(time.Now().Add(time.Hour))) + assert.True(t, nc.Expired(time.Now().Add(-time.Hour))) + assert.False(t, nc.Expired(time.Now())) +} + +func TestCertificateV2_MarshalJSON(t *testing.T) { + time.Local = time.UTC + pubKey := []byte("1234567890abcedf1234567890abcedf") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), + notAfter: time.Date(1, 0, 0, 2, 0, 0, 0, time.UTC), + isCA: false, + issuer: "1234567890abcedf1234567890abcedf", + }, + publicKey: pubKey, + signature: []byte("1234567890abcedf1234567890abcedf1234567890abcedf1234567890abcedf"), + } + + b, err := nc.MarshalJSON() + assert.ErrorIs(t, err, ErrMissingDetails) + + rd, err := nc.details.Marshal() + assert.NoError(t, err) + + nc.rawDetails = rd + b, err = nc.MarshalJSON() + assert.Nil(t, err) + assert.Equal( + t, + "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", + string(b), + ) +} + +func TestCertificateV2_VerifyPrivateKey(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + _, caKey2, err := ed25519.GenerateKey(rand.Reader) + require.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) + assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_CURVE25519, curve) + err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) + assert.Nil(t, err) + + _, priv2 := X25519Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) + assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + + err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + ac, ok := c.(*certificateV2) + require.True(t, ok) + ac.curve = Curve(99) + err = c.VerifyPrivateKey(Curve(99), priv2) + assert.EqualError(t, err, "invalid curve: 99") + + ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) + assert.Nil(t, err) + + err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) + + err = c.VerifyPrivateKey(Curve_P256, priv[:16]) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + err = c.VerifyPrivateKey(Curve_P256, priv) + assert.ErrorIs(t, err, ErrInvalidPrivateKey) + + aCa, ok := ca2.(*certificateV2) + require.True(t, ok) + aCa.curve = Curve(99) + err = aCa.VerifyPrivateKey(Curve(99), priv2) + assert.EqualError(t, err, "invalid curve: 99") + +} + +func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + err := ca.VerifyPrivateKey(Curve_P256, caKey) + assert.Nil(t, err) + + _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) + assert.Nil(t, err) + err = ca.VerifyPrivateKey(Curve_P256, caKey2) + assert.NotNil(t, err) + + c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) + rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) + assert.NoError(t, err) + assert.Empty(t, b) + assert.Equal(t, Curve_P256, curve) + err = c.VerifyPrivateKey(Curve_P256, rawPriv) + assert.Nil(t, err) + + _, priv2 := P256Keypair() + err = c.VerifyPrivateKey(Curve_P256, priv2) + assert.NotNil(t, err) +} + +func TestCertificateV2_Copy(t *testing.T) { + ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) + c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) + cc := c.Copy() + test.AssertDeepCopyEqual(t, c, cc) +} + +func TestUnmarshalCertificateV2(t *testing.T) { + data := []byte("\x98\x00\x00") + _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) + assert.EqualError(t, err, "bad wire format") +} + +func TestCertificateV2_marshalForSigningStability(t *testing.T) { + before := time.Date(1996, time.May, 5, 0, 0, 0, 0, time.UTC) + after := before.Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + nc := certificateV2{ + details: detailsV2{ + name: "testing", + networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.2/16"), + mustParsePrefixUnmapped("10.1.1.1/24"), + }, + unsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.3/16"), + mustParsePrefixUnmapped("9.1.1.2/24"), + }, + groups: []string{"test-group1", "test-group2", "test-group3"}, + notBefore: before, + notAfter: after, + isCA: false, + issuer: "1234567890abcdef1234567890abcdef", + }, + signature: []byte("1234567890abcdef1234567890abcdef"), + publicKey: pubKey, + } + + const expectedRawDetailsStr = "a070800774657374696e67a10e04050a0101021004050a01010118a20e0405090101031004050901010218a3270c0b746573742d67726f7570310c0b746573742d67726f7570320c0b746573742d67726f7570338504318bef808604318befbc87101234567890abcdef1234567890abcdef" + expectedRawDetails, err := hex.DecodeString(expectedRawDetailsStr) + require.NoError(t, err) + + db, err := nc.details.Marshal() + require.NoError(t, err) + assert.Equal(t, expectedRawDetails, db) + + expectedForSigning, err := hex.DecodeString(expectedRawDetailsStr + "00313233343536373839306162636564666768696a313233343536373839306162") + b, err := nc.marshalForSigning() + require.NoError(t, err) + assert.Equal(t, expectedForSigning, b) +} diff --git a/cert/errors.go b/cert/errors.go index da0d1be3f..60273a99d 100644 --- a/cert/errors.go +++ b/cert/errors.go @@ -5,16 +5,19 @@ import ( ) var ( - ErrBadFormat = errors.New("bad wire format") - ErrRootExpired = errors.New("root certificate is expired") - ErrExpired = errors.New("certificate is expired") - ErrNotCA = errors.New("certificate is not a CA") - ErrNotSelfSigned = errors.New("certificate is not self-signed") - ErrBlockListed = errors.New("certificate is in the block list") - ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") - ErrSignatureMismatch = errors.New("certificate signature did not match") - ErrInvalidPublicKeyLength = errors.New("invalid public key length") - ErrInvalidPrivateKeyLength = errors.New("invalid private key length") + ErrBadFormat = errors.New("bad wire format") + ErrRootExpired = errors.New("root certificate is expired") + ErrExpired = errors.New("certificate is expired") + ErrNotCA = errors.New("certificate is not a CA") + ErrNotSelfSigned = errors.New("certificate is not self-signed") + ErrBlockListed = errors.New("certificate is in the block list") + ErrFingerprintMismatch = errors.New("certificate fingerprint did not match") + ErrSignatureMismatch = errors.New("certificate signature did not match") + ErrInvalidPublicKey = errors.New("invalid public key") + ErrInvalidPrivateKey = errors.New("invalid private key") + ErrPublicPrivateCurveMismatch = errors.New("public key does not match private key curve") + ErrPublicPrivateKeyMismatch = errors.New("public key and private key are not a pair") + ErrCaNotFound = errors.New("could not find ca for the certificate") ErrPrivateKeyEncrypted = errors.New("private key must be decrypted") @@ -24,4 +27,11 @@ var ( ErrInvalidPEMX25519PrivateKeyBanner = errors.New("bytes did not contain a proper X25519 private key banner") ErrInvalidPEMEd25519PublicKeyBanner = errors.New("bytes did not contain a proper Ed25519 public key banner") ErrInvalidPEMEd25519PrivateKeyBanner = errors.New("bytes did not contain a proper Ed25519 private key banner") + + ErrNoPeerStaticKey = errors.New("no peer static key was present") + ErrNoPayload = errors.New("provided payload was empty") + + ErrMissingDetails = errors.New("certificate did not contain details") + ErrEmptySignature = errors.New("empty signature") + ErrEmptyRawDetails = errors.New("empty rawDetails not allowed") ) diff --git a/cert/helper_test.go b/cert/helper_test.go new file mode 100644 index 000000000..05142dd54 --- /dev/null +++ b/cert/helper_test.go @@ -0,0 +1,137 @@ +package cert + +import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "io" + "net/netip" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/ed25519" +) + +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version Version, curve Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + t := &TBSCertificate{ + Curve: curve, + Version: version, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, + } + + c, err := t.Sign(nil, curve, priv) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, priv, pem +} + +// NewTestCert will generate a signed certificate with the provided details. +// Expiry times are defaulted if you do not pass them in +func NewTestCert(v Version, curve Curve, ca Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (Certificate, []byte, []byte, []byte) { + if before.IsZero() { + before = time.Now().Add(time.Second * -60).Round(time.Second) + } + + if after.IsZero() { + after = time.Now().Add(time.Second * 60).Round(time.Second) + } + + var pub, priv []byte + switch curve { + case Curve_CURVE25519: + pub, priv = X25519Keypair() + case Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + + nc := &TBSCertificate{ + Version: v, + Curve: curve, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) + if err != nil { + panic(err) + } + + pem, err := c.MarshalPEM() + if err != nil { + panic(err) + } + + return c, pub, MarshalPrivateKeyToPEM(curve, priv), pem +} + +func X25519Keypair() ([]byte, []byte) { + privkey := make([]byte, 32) + if _, err := io.ReadFull(rand.Reader, privkey); err != nil { + panic(err) + } + + pubkey, err := curve25519.X25519(privkey, curve25519.Basepoint) + if err != nil { + panic(err) + } + + return pubkey, privkey +} + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cert/pem.go b/cert/pem.go index 744ae2edf..249b63917 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -30,19 +30,24 @@ func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { return nil, r, ErrInvalidPEMBlock } + var c Certificate + var err error + switch p.Type { case CertificateBanner: - c, err := unmarshalCertificateV1(p.Bytes, true) - if err != nil { - return nil, nil, err - } - return c, r, nil + c, err = unmarshalCertificateV1(p.Bytes, nil) case CertificateV2Banner: - //TODO - panic("TODO") + c, err = unmarshalCertificateV2(p.Bytes, nil, Curve_CURVE25519) default: return nil, r, ErrInvalidPEMCertificateBanner } + + if err != nil { + return nil, r, err + } + + return c, r, nil + } func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { diff --git a/cert/sign.go b/cert/sign.go index e446aa131..a1e09cd2b 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -1,11 +1,16 @@ package cert import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" "fmt" + "math/big" "net/netip" + "slices" "time" - - "github.com/slackhq/nebula/pkclient" ) // TBSCertificate represents a certificate intended to be signed. @@ -24,27 +29,62 @@ type TBSCertificate struct { issuer string } +type beingSignedCertificate interface { + // fromTBSCertificate copies the values from the TBSCertificate to this versions internal representation + fromTBSCertificate(*TBSCertificate) error + + // marshalForSigning returns the bytes that should be signed + marshalForSigning() ([]byte, error) + + // setSignature sets the signature for the certificate that has just been signed. The signature must not be blank. + setSignature([]byte) error +} + +type SignerLambda func(certBytes []byte) ([]byte, error) + // Sign will create a sealed certificate using details provided by the TBSCertificate as long as those // details do not violate constraints of the signing certificate. // If the TBSCertificate is a CA then signer must be nil. func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { - return t.sign(signer, curve, key, nil) -} - -func (t *TBSCertificate) SignPkcs11(signer Certificate, curve Curve, client *pkclient.PKClient) (Certificate, error) { - if curve != Curve_P256 { - return nil, fmt.Errorf("only P256 is supported by PKCS#11") + switch t.Curve { + case Curve_CURVE25519: + pk := ed25519.PrivateKey(key) + sp := func(certBytes []byte) ([]byte, error) { + sig := ed25519.Sign(pk, certBytes) + return sig, nil + } + return t.SignWith(signer, curve, sp) + case Curve_P256: + pk := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + pk.X, pk.Y = pk.Curve.ScalarBaseMult(key) + sp := func(certBytes []byte) ([]byte, error) { + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(certBytes) + return ecdsa.SignASN1(rand.Reader, pk, hashed[:]) + } + return t.SignWith(signer, curve, sp) + default: + return nil, fmt.Errorf("invalid curve: %s", t.Curve) } - - return t.sign(signer, curve, nil, client) } -func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, client *pkclient.PKClient) (Certificate, error) { +// SignWith does the same thing as sign, but uses the function in `sp` to calculate the signature. +// You should only use SignWith if you do not have direct access to your private key. +func (t *TBSCertificate) SignWith(signer Certificate, curve Curve, sp SignerLambda) (Certificate, error) { if curve != t.Curve { return nil, fmt.Errorf("curve in cert and private key supplied don't match") } //TODO: make sure we have all minimum properties to sign, like a public key + //TODO: we need to verify networks and unsafe networks (no duplicates, max of 1 of each version for v2 certs if signer != nil { if t.IsCA { @@ -67,10 +107,54 @@ func (t *TBSCertificate) sign(signer Certificate, curve Curve, key []byte, clien } } + slices.SortFunc(t.Networks, comparePrefix) + slices.SortFunc(t.UnsafeNetworks, comparePrefix) + + var c beingSignedCertificate switch t.Version { case Version1: - return signV1(t, curve, key, client) + c = &certificateV1{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } + case Version2: + c = &certificateV2{} + err := c.fromTBSCertificate(t) + if err != nil { + return nil, err + } default: return nil, fmt.Errorf("unknown cert version %d", t.Version) } + + certBytes, err := c.marshalForSigning() + if err != nil { + return nil, err + } + + sig, err := sp(certBytes) + if err != nil { + return nil, err + } + + err = c.setSignature(sig) + if err != nil { + return nil, err + } + + sc, ok := c.(Certificate) + if !ok { + return nil, fmt.Errorf("invalid certificate") + } + + return sc, nil +} + +func comparePrefix(a, b netip.Prefix) int { + addr := a.Addr().Compare(b.Addr()) + if addr == 0 { + return a.Bits() - b.Bits() + } + return addr } diff --git a/cert/sign_test.go b/cert/sign_test.go new file mode 100644 index 000000000..2b8dbe82f --- /dev/null +++ b/cert/sign_test.go @@ -0,0 +1,90 @@ +package cert + +import ( + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestCertificateV1_Sign(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/24"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + } + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) + assert.Nil(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + assert.Nil(t, err) + uc, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + assert.NotNil(t, uc) +} + +func TestCertificateV1_SignP256(t *testing.T) { + before := time.Now().Add(time.Second * -60).Round(time.Second) + after := time.Now().Add(time.Second * 60).Round(time.Second) + pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") + + tbs := TBSCertificate{ + Version: Version1, + Name: "testing", + Networks: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + }, + UnsafeNetworks: []netip.Prefix{ + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), + }, + Groups: []string{"test-group1", "test-group2", "test-group3"}, + NotBefore: before, + NotAfter: after, + PublicKey: pubKey, + IsCA: false, + Curve: Curve_P256, + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + assert.NoError(t, err) + pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) + rawPriv := priv.D.FillBytes(make([]byte, 32)) + + c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) + assert.Nil(t, err) + assert.NotNil(t, c) + assert.True(t, c.CheckSignature(pub)) + + b, err := c.Marshal() + assert.Nil(t, err) + uc, err := unmarshalCertificateV1(b, nil) + assert.Nil(t, err) + assert.NotNil(t, uc) +} diff --git a/e2e/helpers.go b/cert_test/cert.go similarity index 51% rename from e2e/helpers.go rename to cert_test/cert.go index c0893aca2..ebc6f522d 100644 --- a/e2e/helpers.go +++ b/cert_test/cert.go @@ -1,6 +1,9 @@ -package e2e +package cert_test import ( + "crypto/ecdh" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "io" "net/netip" @@ -11,9 +14,26 @@ import ( "golang.org/x/crypto/ed25519" ) -// NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { - pub, priv, err := ed25519.GenerateKey(rand.Reader) +// NewTestCaCert will create a new ca certificate +func NewTestCaCert(version cert.Version, curve cert.Curve, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { + var err error + var pub, priv []byte + + switch curve { + case cert.Curve_CURVE25519: + pub, priv, err = ed25519.GenerateKey(rand.Reader) + case cert.Curve_P256: + privk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + + pub = elliptic.Marshal(elliptic.P256(), privk.PublicKey.X, privk.PublicKey.Y) + priv = privk.D.FillBytes(make([]byte, 32)) + default: + // There is no default to allow the underlying lib to respond with an error + } + if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -22,7 +42,8 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre } t := &cert.TBSCertificate{ - Version: cert.Version1, + Curve: curve, + Version: version, Name: "test ca", NotBefore: time.Unix(before.Unix(), 0), NotAfter: time.Unix(after.Unix(), 0), @@ -33,7 +54,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre IsCA: true, } - c, err := t.Sign(nil, cert.Curve_CURVE25519, priv) + c, err := t.Sign(nil, curve, priv) if err != nil { panic(err) } @@ -48,7 +69,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { +func NewTestCert(v cert.Version, curve cert.Curve, ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -57,9 +78,19 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim after = time.Now().Add(time.Second * 60).Round(time.Second) } - pub, rawPriv := x25519Keypair() + var pub, priv []byte + switch curve { + case cert.Curve_CURVE25519: + pub, priv = X25519Keypair() + case cert.Curve_P256: + pub, priv = P256Keypair() + default: + panic("unknown curve") + } + nc := &cert.TBSCertificate{ - Version: cert.Version1, + Version: v, + Curve: curve, Name: name, Networks: networks, UnsafeNetworks: unsafeNetworks, @@ -80,10 +111,10 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim panic(err) } - return c, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem + return c, pub, cert.MarshalPrivateKeyToPEM(curve, priv), pem } -func x25519Keypair() ([]byte, []byte) { +func X25519Keypair() ([]byte, []byte) { privkey := make([]byte, 32) if _, err := io.ReadFull(rand.Reader, privkey); err != nil { panic(err) @@ -96,3 +127,12 @@ func x25519Keypair() ([]byte, []byte) { return pubkey, privkey } + +func P256Keypair() ([]byte, []byte) { + privkey, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + panic(err) + } + pubkey := privkey.PublicKey() + return pubkey.Bytes(), privkey.Bytes() +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 90ea8ffa6..f83c94fb4 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -27,34 +27,43 @@ type caFlags struct { outCertPath *string outQRPath *string groups *string - ips *string - subnets *string + networks *string + unsafeNetworks *string argonMemory *uint argonIterations *uint argonParallelism *uint encryption *bool + version *uint curve *string p11url *string + + // Deprecated options + ips *string + subnets *string } func newCaFlags() *caFlags { cf := caFlags{set: flag.NewFlagSet("ca", flag.ContinueOnError)} cf.set.Usage = func() {} cf.name = cf.set.String("name", "", "Required: name of the certificate authority") + cf.version = cf.set.Uint("version", uint(cert.Version2), "Optional: version of the certificate format to use") cf.duration = cf.set.Duration("duration", time.Duration(time.Hour*8760), "Optional: amount of time the certificate should be valid for. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") cf.outKeyPath = cf.set.String("out-key", "ca.key", "Optional: path to write the private key to") cf.outCertPath = cf.set.String("out-crt", "ca.crt", "Optional: path to write the certificate to") cf.outQRPath = cf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") cf.groups = cf.set.String("groups", "", "Optional: comma separated list of groups. This will limit which groups subordinate certs can use") - cf.ips = cf.set.String("ips", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses") - cf.subnets = cf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets") + cf.networks = cf.set.String("networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks") + cf.unsafeNetworks = cf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks") cf.argonMemory = cf.set.Uint("argon-memory", 2*1024*1024, "Optional: Argon2 memory parameter (in KiB) used for encrypted private key passphrase") cf.argonParallelism = cf.set.Uint("argon-parallelism", 4, "Optional: Argon2 parallelism parameter used for encrypted private key passphrase") cf.argonIterations = cf.set.Uint("argon-iterations", 1, "Optional: Argon2 iterations parameter used for encrypted private key passphrase") cf.encryption = cf.set.Bool("encrypt", false, "Optional: prompt for passphrase and write out-key in an encrypted format") cf.curve = cf.set.String("curve", "25519", "EdDSA/ECDSA Curve (25519, P256)") cf.p11url = p11Flag(cf.set) + + cf.ips = cf.set.String("ips", "", "Deprecated, see -networks") + cf.subnets = cf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &cf } @@ -113,36 +122,51 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - var ips []netip.Prefix - if *cf.ips != "" { - for _, rs := range strings.Split(*cf.ips, ",") { + version := cert.Version(*cf.version) + if version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + + var networks []netip.Prefix + if *cf.networks == "" && *cf.ips != "" { + // Pull up deprecated -ips flag if needed + *cf.networks = *cf.ips + } + + if *cf.networks != "" { + for _, rs := range strings.Split(*cf.networks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid ip definition: %s", err) + return newHelpErrorf("invalid -networks definition: %s", rs) } - if !n.Addr().Is4() { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4, have %s", rs) } - ips = append(ips, n) + networks = append(networks, n) } } } - var subnets []netip.Prefix - if *cf.subnets != "" { - for _, rs := range strings.Split(*cf.subnets, ",") { + var unsafeNetworks []netip.Prefix + if *cf.unsafeNetworks == "" && *cf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *cf.unsafeNetworks = *cf.subnets + } + + if *cf.unsafeNetworks != "" { + for _, rs := range strings.Split(*cf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", err) + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if !n.Addr().Is4() { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) + if version == cert.Version1 && !n.Addr().Is4() { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4, have %s", rs) } - subnets = append(subnets, n) + unsafeNetworks = append(unsafeNetworks, n) } } } @@ -222,11 +246,11 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } t := &cert.TBSCertificate{ - Version: cert.Version1, + Version: version, Name: *cf.name, Groups: groups, - Networks: ips, - UnsafeNetworks: subnets, + Networks: networks, + UnsafeNetworks: unsafeNetworks, NotBefore: time.Now(), NotAfter: time.Now().Add(*cf.duration), PublicKey: pub, @@ -248,7 +272,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error var b []byte if isP11 { - c, err = t.SignPkcs11(nil, curve, p11Client) + c, err = t.SignWith(nil, curve, p11Client.SignASN1) if err != nil { return fmt.Errorf("error while signing with PKCS#11: %w", err) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 06a24edd2..d32044340 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -43,9 +43,11 @@ func Test_caHelp(t *testing.T) { " -groups string\n"+ " \tOptional: comma separated list of groups. This will limit which groups subordinate certs can use\n"+ " -ips string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use for ip addresses\n"+ + " Deprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the certificate authority\n"+ + " -networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in networks\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to (default \"ca.crt\")\n"+ " -out-key string\n"+ @@ -54,7 +56,11 @@ func Test_caHelp(t *testing.T) { " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. This will limit which ipv4 addresses and networks subordinate certs can use in subnets\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. This will limit which ip addresses and networks subordinate certs can use in unsafe networks\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use (default 2)\n", ob.String(), ) } @@ -83,25 +89,25 @@ func Test_ca(t *testing.T) { // required args assertHelpError(t, ca( - []string{"-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, + []string{"-version", "1", "-out-key", "nope", "-out-crt", "nope", "duration", "100m"}, ob, eb, nopw, ), "-name is required") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only ips - assertHelpError(t, ca([]string{"-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-ips", "100::100/100"}, ob, eb, nopw), "invalid -networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // ipv4 only subnets - assertHelpError(t, ca([]string{"-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + assertHelpError(t, ca([]string{"-version", "1", "-name", "ipv6", "-subnets", "100::100/100"}, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4, have 100::100/100") assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // failed key write ob.Reset() eb.Reset() - args := []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} + args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -114,7 +120,7 @@ func Test_ca(t *testing.T) { // failed cert write ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -128,7 +134,7 @@ func Test_ca(t *testing.T) { // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -161,7 +167,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -189,7 +195,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -199,7 +205,7 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, "", eb.String()) @@ -209,13 +215,13 @@ func Test_ca(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.Nil(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -224,7 +230,7 @@ func Test_ca(t *testing.T) { os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} + args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index a62c22338..30e0965b9 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -49,6 +49,8 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { var qrBytes []byte part := 0 + var jsonCerts []cert.Certificate + for { c, rawCert, err = cert.UnmarshalCertificateFromPEM(rawCert) if err != nil { @@ -56,13 +58,10 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { } if *pf.json { - b, _ := json.Marshal(c) - out.Write(b) - out.Write([]byte("\n")) - + jsonCerts = append(jsonCerts, c) } else { - out.Write([]byte(c.String())) - out.Write([]byte("\n")) + _, _ = out.Write([]byte(c.String())) + _, _ = out.Write([]byte("\n")) } if *pf.outQRPath != "" { @@ -80,6 +79,12 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { part++ } + if *pf.json { + b, _ := json.Marshal(jsonCerts) + _, _ = out.Write(b) + _, _ = out.Write([]byte("\n")) + } + if *pf.outQRPath != "" { b, err := qrcode.Encode(string(qrBytes), qrcode.Medium, -5) if err != nil { diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 4c9a72db4..a6737543b 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -87,7 +87,65 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", + `{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +{ + "details": { + "curve": "CURVE25519", + "groups": [ + "hi" + ], + "isCa": false, + "issuer": "`+c.Issuer()+`", + "name": "test", + "networks": [], + "notAfter": "0001-01-01T00:00:00Z", + "notBefore": "0001-01-01T00:00:00Z", + "publicKey": "`+pk+`", + "unsafeNetworks": [] + }, + "fingerprint": "`+fp+`", + "signature": "`+sig+`", + "version": 1 +} +`, ob.String(), ) assert.Equal(t, "", eb.String()) @@ -108,7 +166,8 @@ func Test_printCert(t *testing.T) { assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"hi\"],\"ips\":[],\"isCa\":false,\"issuer\":\""+c.Issuer()+"\",\"name\":\"test\",\"notAfter\":\"0001-01-01T00:00:00Z\",\"notBefore\":\"0001-01-01T00:00:00Z\",\"publicKey\":\""+pk+"\",\"subnets\":[]},\"fingerprint\":\""+fp+"\",\"signature\":\""+sig+"\"}\n", + `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":[],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] +`, ob.String(), ) assert.Equal(t, "", eb.String()) diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 13e807f3b..253ef864e 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -3,6 +3,7 @@ package main import ( "crypto/ecdh" "crypto/rand" + "errors" "flag" "fmt" "io" @@ -18,36 +19,46 @@ import ( ) type signFlags struct { - set *flag.FlagSet - caKeyPath *string - caCertPath *string - name *string - ip *string - duration *time.Duration - inPubPath *string - outKeyPath *string - outCertPath *string - outQRPath *string - groups *string - subnets *string - p11url *string + set *flag.FlagSet + version *uint + caKeyPath *string + caCertPath *string + name *string + networks *string + unsafeNetworks *string + duration *time.Duration + inPubPath *string + outKeyPath *string + outCertPath *string + outQRPath *string + groups *string + + p11url *string + + // Deprecated options + ip *string + subnets *string } func newSignFlags() *signFlags { sf := signFlags{set: flag.NewFlagSet("sign", flag.ContinueOnError)} sf.set.Usage = func() {} + sf.version = sf.set.Uint("version", 0, "Optional: version of the certificate format to use, the default is to create both v1 and v2 certificates.") sf.caKeyPath = sf.set.String("ca-key", "ca.key", "Optional: path to the signing CA key") sf.caCertPath = sf.set.String("ca-crt", "ca.crt", "Optional: path to the signing CA cert") sf.name = sf.set.String("name", "", "Required: name of the cert, usually a hostname") - sf.ip = sf.set.String("ip", "", "Required: ipv4 address and network in CIDR notation to assign the cert") + sf.networks = sf.set.String("networks", "", "Required: comma separated list of ip address and network in CIDR notation to assign to this cert") + sf.unsafeNetworks = sf.set.String("unsafe-networks", "", "Optional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for") sf.duration = sf.set.Duration("duration", 0, "Optional: how long the cert should be valid for. The default is 1 second before the signing cert expires. Valid time units are seconds: \"s\", minutes: \"m\", hours: \"h\"") sf.inPubPath = sf.set.String("in-pub", "", "Optional (if out-key not set): path to read a previously generated public key") sf.outKeyPath = sf.set.String("out-key", "", "Optional (if in-pub not set): path to write the private key to") sf.outCertPath = sf.set.String("out-crt", "", "Optional: path to write the certificate to") sf.outQRPath = sf.set.String("out-qr", "", "Optional: output a qr code image (png) of the certificate") sf.groups = sf.set.String("groups", "", "Optional: comma separated list of groups") - sf.subnets = sf.set.String("subnets", "", "Optional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for") sf.p11url = p11Flag(sf.set) + + sf.ip = sf.set.String("ip", "", "Deprecated, see -networks") + sf.subnets = sf.set.String("subnets", "", "Deprecated, see -unsafe-networks") return &sf } @@ -71,13 +82,26 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) if err := mustFlagString("name", sf.name); err != nil { return err } - if err := mustFlagString("ip", sf.ip); err != nil { - return err - } if !isP11 && *sf.inPubPath != "" && *sf.outKeyPath != "" { return newHelpErrorf("cannot set both -in-pub and -out-key") } + var v4Networks []netip.Prefix + var v6Networks []netip.Prefix + if *sf.networks == "" && *sf.ip != "" { + // Pull up deprecated -ip flag if needed + *sf.networks = *sf.ip + } + + if len(*sf.networks) == 0 { + return newHelpErrorf("-networks is required") + } + + version := cert.Version(*sf.version) + if version != 0 && version != cert.Version1 && version != cert.Version2 { + return newHelpErrorf("-version must be either %v or %v", cert.Version1, cert.Version2) + } + var curve cert.Curve var caKey []byte @@ -91,14 +115,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) // naively attempt to decode the private key as though it is not encrypted caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) - if err == cert.ErrPrivateKeyEncrypted { + if errors.Is(err, cert.ErrPrivateKeyEncrypted) { // ask for a passphrase until we get one var passphrase []byte for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() - if err == ErrNoTerminal { + if errors.Is(err, ErrNoTerminal) { return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") } else if err != nil { return fmt.Errorf("error reading password: %s", err) @@ -146,37 +170,57 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - network, err := netip.ParsePrefix(*sf.ip) - if err != nil { - return newHelpErrorf("invalid ip definition: %s", *sf.ip) - } - if !network.Addr().Is4() { - return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) - } + if *sf.networks != "" { + for _, rs := range strings.Split(*sf.networks, ",") { + //TODO: error on duplicates? Mainly only addr matters, having two of the same addr in the same or different prefix space is strange + rs := strings.Trim(rs, " ") + if rs != "" { + n, err := netip.ParsePrefix(rs) + if err != nil { + return newHelpErrorf("invalid -networks definition: %s", rs) + } - var groups []string - if *sf.groups != "" { - for _, rg := range strings.Split(*sf.groups, ",") { - g := strings.TrimSpace(rg) - if g != "" { - groups = append(groups, g) + if n.Addr().Is4() { + v4Networks = append(v4Networks, n) + } else { + v6Networks = append(v6Networks, n) + } } } } - var subnets []netip.Prefix - if *sf.subnets != "" { - for _, rs := range strings.Split(*sf.subnets, ",") { + var v4UnsafeNetworks []netip.Prefix + var v6UnsafeNetworks []netip.Prefix + if *sf.unsafeNetworks == "" && *sf.subnets != "" { + // Pull up deprecated -subnets flag if needed + *sf.unsafeNetworks = *sf.subnets + } + + if *sf.unsafeNetworks != "" { + //TODO: error on duplicates? + for _, rs := range strings.Split(*sf.unsafeNetworks, ",") { rs := strings.Trim(rs, " ") if rs != "" { - s, err := netip.ParsePrefix(rs) + n, err := netip.ParsePrefix(rs) if err != nil { - return newHelpErrorf("invalid subnet definition: %s", rs) + return newHelpErrorf("invalid -unsafe-networks definition: %s", rs) } - if !s.Addr().Is4() { - return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) + + if n.Addr().Is4() { + v4UnsafeNetworks = append(v4UnsafeNetworks, n) + } else { + v6UnsafeNetworks = append(v6UnsafeNetworks, n) } - subnets = append(subnets, s) + } + } + } + + var groups []string + if *sf.groups != "" { + for _, rg := range strings.Split(*sf.groups, ",") { + g := strings.TrimSpace(rg) + if g != "" { + groups = append(groups, g) } } } @@ -218,19 +262,6 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) pub, rawPriv = newKeypair(curve) } - t := &cert.TBSCertificate{ - Version: cert.Version1, - Name: *sf.name, - Networks: []netip.Prefix{network}, - Groups: groups, - UnsafeNetworks: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*sf.duration), - PublicKey: pub, - IsCA: false, - Curve: curve, - } - if *sf.outKeyPath == "" { *sf.outKeyPath = *sf.name + ".key" } @@ -243,18 +274,85 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - var c cert.Certificate + var crts []cert.Certificate - if p11Client == nil { - c, err = t.Sign(caCert, curve, caKey) - if err != nil { - return fmt.Errorf("error while signing: %w", err) + notBefore := time.Now() + notAfter := notBefore.Add(*sf.duration) + + if version == 0 || version == cert.Version1 { + // Make sure we at least have an ip + if len(v4Networks) != 1 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only have a single ipv4 address") } - } else { - c, err = t.SignPkcs11(caCert, curve, p11Client) - if err != nil { - return fmt.Errorf("error while signing with PKCS#11: %w", err) + + if version == cert.Version1 { + // If we are asked to mint a v1 certificate only then we cant just ignore any v6 addresses + if len(v6Networks) > 0 { + return newHelpErrorf("invalid -networks definition: v1 certificates can only be ipv4") + } + + if len(v6UnsafeNetworks) > 0 { + return newHelpErrorf("invalid -unsafe-networks definition: v1 certificates can only be ipv4") + } + } + + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *sf.name, + Networks: []netip.Prefix{v4Networks[0]}, + Groups: groups, + UnsafeNetworks: v4UnsafeNetworks, + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } } + + crts = append(crts, nc) + } + + if version == 0 || version == cert.Version2 { + t := &cert.TBSCertificate{ + Version: cert.Version2, + Name: *sf.name, + Networks: append(v4Networks, v6Networks...), + Groups: groups, + UnsafeNetworks: append(v4UnsafeNetworks, v6UnsafeNetworks...), + NotBefore: notBefore, + NotAfter: notAfter, + PublicKey: pub, + IsCA: false, + Curve: curve, + } + + var nc cert.Certificate + if p11Client == nil { + nc, err = t.Sign(caCert, curve, caKey) + if err != nil { + return fmt.Errorf("error while signing: %w", err) + } + } else { + nc, err = t.SignWith(caCert, curve, p11Client.SignASN1) + if err != nil { + return fmt.Errorf("error while signing with PKCS#11: %w", err) + } + } + + crts = append(crts, nc) } if !isP11 && *sf.inPubPath == "" { @@ -268,9 +366,13 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - b, err := c.MarshalPEM() - if err != nil { - return fmt.Errorf("error while marshalling certificate: %s", err) + var b []byte + for _, c := range crts { + sb, err := c.MarshalPEM() + if err != nil { + return fmt.Errorf("error while marshalling certificate: %s", err) + } + b = append(b, sb...) } err = os.WriteFile(*sf.outCertPath, b, 0600) diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index b68434df7..b4fdc43be 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -39,9 +39,11 @@ func Test_signHelp(t *testing.T) { " -in-pub string\n"+ " \tOptional (if out-key not set): path to read a previously generated public key\n"+ " -ip string\n"+ - " \tRequired: ipv4 address and network in CIDR notation to assign the cert\n"+ + " \tDeprecated, see -networks\n"+ " -name string\n"+ " \tRequired: name of the cert, usually a hostname\n"+ + " -networks string\n"+ + " \tRequired: comma separated list of ip address and network in CIDR notation to assign to this cert\n"+ " -out-crt string\n"+ " \tOptional: path to write the certificate to\n"+ " -out-key string\n"+ @@ -50,7 +52,11 @@ func Test_signHelp(t *testing.T) { " \tOptional: output a qr code image (png) of the certificate\n"+ optionalPkcs11String(" -pkcs11 string\n \tOptional: PKCS#11 URI to an existing private key\n")+ " -subnets string\n"+ - " \tOptional: comma separated list of ipv4 address and network in CIDR notation. Subnets this cert can serve for\n", + " \tDeprecated, see -unsafe-networks\n"+ + " -unsafe-networks string\n"+ + " \tOptional: comma separated list of ip address and network in CIDR notation. Unsafe networks this cert can route for\n"+ + " -version uint\n"+ + " \tOptional: version of the certificate format to use, the default is to create both v1 and v2 certificates.\n", ob.String(), ) } @@ -77,20 +83,20 @@ func Test_signCert(t *testing.T) { // required args assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-ip", "1.1.1.1/24", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, ), "-name is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, - ), "-ip is required") + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-out-key", "nope", "-out-crt", "nope"}, ob, eb, nopw, + ), "-networks is required") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // cannot set -in-pub and -out-key assertHelpError(t, signCert( - []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, + []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-in-pub", "nope", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope"}, ob, eb, nopw, ), "cannot set both -in-pub and -out-key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -98,7 +104,7 @@ func Test_signCert(t *testing.T) { // failed to read key ob.Reset() eb.Reset() - args := []string{"-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key @@ -108,7 +114,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caKeyF.Name()) - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -120,7 +126,7 @@ func Test_signCert(t *testing.T) { caKeyF.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv)) // failed to read cert - args = []string{"-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -132,7 +138,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(caCrtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -143,7 +149,7 @@ func Test_signCert(t *testing.T) { caCrtF.Write(b) // failed to read pub - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -155,7 +161,7 @@ func Test_signCert(t *testing.T) { assert.Nil(t, err) defer os.Remove(inPubF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -169,30 +175,37 @@ func Test_signCert(t *testing.T) { // bad ip cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: a1.1.1.1/24") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "a1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: a1.1.1.1/24") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid ip definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "100::100/100", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") + assert.Empty(t, ob.String()) + assert.Empty(t, eb.String()) + + ob.Reset() + eb.Reset() + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24,1.1.1.2/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -networks definition: v1 certificates can only have a single ipv4 address") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // bad subnet cidr ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: a") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: a") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} - assertHelpError(t, signCert(args, ob, eb, nopw), "invalid subnet definition: can only be ipv4, have 100::100/100") + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "100::100/100"} + assertHelpError(t, signCert(args, ob, eb, nopw), "invalid -unsafe-networks definition: v1 certificates can only be ipv4") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -205,7 +218,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -213,7 +226,7 @@ func Test_signCert(t *testing.T) { // failed key write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -226,7 +239,7 @@ func Test_signCert(t *testing.T) { // failed cert write ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -240,7 +253,7 @@ func Test_signCert(t *testing.T) { // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -283,7 +296,7 @@ func Test_signCert(t *testing.T) { os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} assert.Nil(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -300,7 +313,7 @@ func Test_signCert(t *testing.T) { eb.Reset() os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -308,14 +321,14 @@ func Test_signCert(t *testing.T) { // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -323,14 +336,14 @@ func Test_signCert(t *testing.T) { // create valid cert/key for overwrite tests os.Remove(keyF.Name()) os.Remove(crtF.Name()) - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -362,7 +375,7 @@ func Test_signCert(t *testing.T) { caCrtF.Write(b) // test with the proper password - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Nil(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -372,7 +385,7 @@ func Test_signCert(t *testing.T) { eb.Reset() testpw.password = []byte("invalid password") - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -381,7 +394,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) @@ -391,7 +404,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() - args = []string{"-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} + args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} assert.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index 80cfef3c0..bea4d1d94 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -1,6 +1,7 @@ package main import ( + "errors" "flag" "fmt" "io" @@ -41,14 +42,14 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCACert, err := os.ReadFile(*vf.caPath) if err != nil { - return fmt.Errorf("error while reading ca: %s", err) + return fmt.Errorf("error while reading ca: %w", err) } caPool := cert.NewCAPool() for { rawCACert, err = caPool.AddCAFromPEM(rawCACert) if err != nil { - return fmt.Errorf("error while adding ca cert to pool: %s", err) + return fmt.Errorf("error while adding ca cert to pool: %w", err) } if rawCACert == nil || len(rawCACert) == 0 || strings.TrimSpace(string(rawCACert)) == "" { @@ -58,20 +59,30 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { rawCert, err := os.ReadFile(*vf.certPath) if err != nil { - return fmt.Errorf("unable to read crt; %s", err) + return fmt.Errorf("unable to read crt: %w", err) } - - c, _, err := cert.UnmarshalCertificateFromPEM(rawCert) - if err != nil { - return fmt.Errorf("error while parsing crt: %s", err) - } - - _, err = caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return err + var errs []error + for { + if len(rawCert) == 0 { + break + } + c, extra, err := cert.UnmarshalCertificateFromPEM(rawCert) + if err != nil { + return fmt.Errorf("error while parsing crt: %w", err) + } + rawCert = extra + _, err = caPool.VerifyCertificate(time.Now(), c) + if err != nil { + switch { + case errors.Is(err, cert.ErrCaNotFound): + errs = append(errs, fmt.Errorf("error while verifying certificate v%d %s with issuer %s: %w", c.Version(), c.Name(), c.Issuer(), err)) + default: + errs = append(errs, fmt.Errorf("error while verifying certificate %+v: %w", c, err)) + } + } } - return nil + return errors.Join(errs...) } func verifySummary() string { @@ -80,7 +91,7 @@ func verifySummary() string { func verifyHelp(out io.Writer) { vf := newVerifyFlags() - out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) + _, _ = out.Write([]byte("Usage of " + os.Args[0] + " " + verifySummary() + "\n")) vf.set.SetOutput(out) vf.set.PrintDefaults() } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index 204ff09ff..d94bd1fcb 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -3,10 +3,12 @@ package main import ( "bytes" "crypto/rand" + "errors" "os" "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" "golang.org/x/crypto/ed25519" ) @@ -76,7 +78,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt; open does_not_exist: "+NoSuchFileError) + assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() @@ -106,7 +108,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "certificate signature did not match") + assert.True(t, errors.Is(err, cert.ErrSignatureMismatch)) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) diff --git a/connection_manager.go b/connection_manager.go index 77182523f..9d8d0715f 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -183,7 +183,7 @@ func (n *connectionManager) doTrafficCheck(localIndex uint32, p, nb, out []byte, case deleteTunnel: if n.hostMap.DeleteHostInfo(hostinfo) { // Only clearing the lighthouse cache if this is the last hostinfo for this vpn ip in the hostmap - n.intf.lightHouse.DeleteVpnIp(hostinfo.vpnIp) + n.intf.lightHouse.DeleteVpnAddrs(hostinfo.vpnAddrs) } case closeTunnel: @@ -221,7 +221,7 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) relayFor := oldhostinfo.relayState.CopyAllRelayFor() for _, r := range relayFor { - existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerIp) + existing, ok := newhostinfo.relayState.QueryRelayForByIp(r.PeerAddr) var index uint32 var relayFrom netip.Addr @@ -235,11 +235,11 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) index = existing.LocalIndex switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = existing.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = existing.PeerAddr case ForwardingType: - relayFrom = existing.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = existing.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } @@ -253,45 +253,64 @@ func (n *connectionManager) migrateRelayUsed(oldhostinfo, newhostinfo *HostInfo) n.relayUsedLock.RUnlock() // The relay doesn't exist at all; create some relay state and send the request. var err error - index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerIp, nil, r.Type, Requested) + index, err = AddRelay(n.l, newhostinfo, n.hostMap, r.PeerAddr, nil, r.Type, Requested) if err != nil { n.l.WithError(err).Error("failed to migrate relay to new hostinfo") continue } switch r.Type { case TerminalType: - relayFrom = n.intf.myVpnNet.Addr() - relayTo = r.PeerIp + relayFrom = n.intf.myVpnAddrs[0] + relayTo = r.PeerAddr case ForwardingType: - relayFrom = r.PeerIp - relayTo = newhostinfo.vpnIp + relayFrom = r.PeerAddr + relayTo = newhostinfo.vpnAddrs[0] default: // should never happen } } - //TODO: IPV6-WORK - relayFromB := relayFrom.As4() - relayToB := relayTo.As4() - // Send a CreateRelayRequest to the peer. req := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(relayFromB[:]), - RelayToIp: binary.BigEndian.Uint32(relayToB[:]), } + + switch newhostinfo.GetCert().Certificate.Version() { + case cert.Version1: + if !relayFrom.Is4() { + n.l.Error("can not migrate v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !relayTo.Is4() { + n.l.Error("can not migrate v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := relayFrom.As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = relayTo.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + req.RelayFromAddr = netAddrToProtoAddr(relayFrom) + req.RelayToAddr = netAddrToProtoAddr(relayTo) + default: + newhostinfo.logger(n.l).Error("Unknown certificate version found while attempting to migrate relay") + continue + } + msg, err := req.Marshal() if err != nil { n.l.WithError(err).Error("failed to marshal Control message to migrate relay") } else { n.intf.SendMessageToHostInfo(header.Control, 0, newhostinfo, msg, make([]byte, 12), make([]byte, mtu)) n.l.WithFields(logrus.Fields{ - "relayFrom": req.RelayFromIp, - "relayTo": req.RelayToIp, + "relayFrom": req.RelayFromAddr, + "relayTo": req.RelayToAddr, "initiatorRelayIndex": req.InitiatorRelayIndex, "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": newhostinfo.vpnIp}). + "vpnAddrs": newhostinfo.vpnAddrs}). Info("send CreateRelayRequest") } } @@ -313,7 +332,7 @@ func (n *connectionManager) makeTrafficDecision(localIndex uint32, now time.Time return closeTunnel, hostinfo, nil } - primary := n.hostMap.Hosts[hostinfo.vpnIp] + primary := n.hostMap.Hosts[hostinfo.vpnAddrs[0]] mainHostInfo := true if primary != nil && primary != hostinfo { mainHostInfo = false @@ -407,21 +426,24 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { // If we are here then we have multiple tunnels for a host pair and neither side believes the same tunnel is primary. // Let's sort this out. - if current.vpnIp.Compare(n.intf.myVpnNet.Addr()) < 0 { - // Only one side should flip primary because if both flip then we may never resolve to a single tunnel. - // vpn ip is static across all tunnels for this host pair so lets use that to determine who is flipping. - // The remotes vpn ip is lower than mine. I will not flip. + // Only one side should swap because if both swap then we may never resolve to a single tunnel. + // vpn addr is static across all tunnels for this host pair so lets + // use that to determine if we should consider swapping. + if current.vpnAddrs[0].Compare(n.intf.myVpnAddrs[0]) < 0 { + // Their primary vpn addr is less than mine. Do not swap. return false } - certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) + crt := n.intf.pki.getCertState().getCertificate(current.ConnectionState.myCert.Version()) + // If this tunnel is using the latest certificate then we should swap it to primary for a bit and see if things + // settle down. + return bytes.Equal(current.ConnectionState.myCert.Signature(), crt.Signature()) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { n.hostMap.Lock() // Make sure the primary is still the same after the write lock. This avoids a race with a rehandshake. - if n.hostMap.Hosts[current.vpnIp] == primary { + if n.hostMap.Hosts[current.vpnAddrs[0]] == primary { n.hostMap.unlockedMakePrimary(current) } n.hostMap.Unlock() @@ -473,14 +495,17 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { } func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { - certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) { + cs := n.intf.pki.getCertState() + curCrt := hostinfo.ConnectionState.myCert + myCrt := cs.getCertificate(curCrt.Version()) + if curCrt.Version() >= cs.defaultVersion && bytes.Equal(curCrt.Signature(), myCrt.Signature()) == true { + // The current tunnel is using the latest certificate and version, no need to rehandshake. return } - n.l.WithField("vpnIp", hostinfo.vpnIp). + n.l.WithField("vpnAddrs", hostinfo.vpnAddrs). WithField("reason", "local certificate is not current"). Info("Re-handshaking with remote") - n.intf.handshakeManager.StartHandshake(hostinfo.vpnIp, nil) + n.intf.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], nil) } diff --git a/connection_manager_test.go b/connection_manager_test.go index 9f222c8b4..8e2ef15ad 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -34,20 +34,19 @@ func newTestLighthouse() *LightHouse { func Test_NewConnectionManagerTest(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -74,12 +73,12 @@ func Test_NewConnectionManagerTest(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &dummyCert{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -88,7 +87,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) assert.Contains(t, nc.out, hostinfo.localIndexId) @@ -105,32 +104,31 @@ func Test_NewConnectionManagerTest(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // Do a final traffic check tick, the host should now be removed nc.doTrafficCheck(hostinfo.localIndexId, p, nb, out, time.Now()) assert.NotContains(t, nc.pendingDeletion, hostinfo.localIndexId) - assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.NotContains(t, nc.hostMap.Indexes, hostinfo.localIndexId) } func Test_NewConnectionManagerTest2(t *testing.T) { l := test.NewLogger() //_, tuncidr, _ := net.ParseCIDR("1.1.1.1/24") - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} // Very incomplete mock objects - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -157,12 +155,12 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // Add an ip we have established a connection w/ to hostmap hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, localIndexId: 1099, remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &dummyCert{}, + myCert: &dummyCert{version: cert.Version1}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -170,8 +168,8 @@ func Test_NewConnectionManagerTest2(t *testing.T) { // We saw traffic out to vpnIp nc.Out(hostinfo.localIndexId) nc.In(hostinfo.localIndexId) - assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnIp) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.NotContains(t, nc.pendingDeletion, hostinfo.vpnAddrs[0]) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) // Do a traffic check tick, should not be pending deletion but should not have any in/out packets recorded @@ -187,7 +185,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) // We saw traffic, should no longer be pending deletion nc.In(hostinfo.localIndexId) @@ -196,7 +194,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { assert.NotContains(t, nc.out, hostinfo.localIndexId) assert.NotContains(t, nc.in, hostinfo.localIndexId) assert.Contains(t, nc.hostMap.Indexes, hostinfo.localIndexId) - assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnIp) + assert.Contains(t, nc.hostMap.Hosts, hostinfo.vpnAddrs[0]) } // Check if we can disconnect the peer. @@ -210,7 +208,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { localrange := netip.MustParsePrefix("10.1.1.1/24") vpnIp := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - hostMap := newHostMap(l, vpncidr) + hostMap := newHostMap(l) hostMap.preferredRanges.Store(&preferredRanges) // Generate keys for CA and peer's cert. @@ -244,10 +242,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + privateKey: []byte{}, + v1Cert: &dummyCert{}, + v1HandshakeBytes: []byte{}, } lh := newTestLighthouse() @@ -273,7 +270,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { ifce.connectionManager = nc hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, ConnectionState: &ConnectionState{ myCert: &dummyCert{}, peerCert: cachedPeerCert, diff --git a/connection_state.go b/connection_state.go index bcc9e5d9a..faee443de 100644 --- a/connection_state.go +++ b/connection_state.go @@ -3,6 +3,7 @@ package nebula import ( "crypto/rand" "encoding/json" + "fmt" "sync" "sync/atomic" @@ -26,46 +27,46 @@ type ConnectionState struct { writeLock sync.Mutex } -func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { +func NewConnectionState(l *logrus.Logger, cs *CertState, crt cert.Certificate, initiator bool, pattern noise.HandshakePattern) (*ConnectionState, error) { var dhFunc noise.DHFunc - switch certState.Certificate.Curve() { + switch crt.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: - if certState.pkcs11Backed { + if cs.pkcs11Backed { dhFunc = noiseutil.DHP256PKCS11 } else { dhFunc = noiseutil.DHP256 } default: - l.Errorf("invalid curve: %s", certState.Certificate.Curve()) - return nil + return nil, fmt.Errorf("invalid curve: %s", crt.Curve()) } - var cs noise.CipherSuite - if cipher == "chachapoly" { - cs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) + var ncs noise.CipherSuite + if cs.cipher == "chachapoly" { + ncs = noise.NewCipherSuite(dhFunc, noise.CipherChaChaPoly, noise.HashSHA256) } else { - cs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) + ncs = noise.NewCipherSuite(dhFunc, noiseutil.CipherAESGCM, noise.HashSHA256) } - static := noise.DHKey{Private: certState.PrivateKey, Public: certState.PublicKey} + static := noise.DHKey{Private: cs.privateKey, Public: crt.PublicKey()} b := NewBits(ReplayWindow) - // Clear out bit 0, we never transmit it and we don't want it showing as packet loss + // Clear out bit 0, we never transmit it, and we don't want it showing as packet loss b.Update(l, 0) hs, err := noise.NewHandshakeState(noise.Config{ - CipherSuite: cs, - Random: rand.Reader, - Pattern: pattern, - Initiator: initiator, - StaticKeypair: static, - PresharedKey: psk, - PresharedKeyPlacement: pskStage, + CipherSuite: ncs, + Random: rand.Reader, + Pattern: pattern, + Initiator: initiator, + StaticKeypair: static, + //NOTE: These should come from CertState (pki.go) when we finally implement it + PresharedKey: []byte{}, + PresharedKeyPlacement: 0, }) if err != nil { - return nil + return nil, fmt.Errorf("NewConnectionState: %s", err) } // The queue and ready params prevent a counter race that would happen when @@ -74,12 +75,12 @@ func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, i H: hs, initiator: initiator, window: b, - myCert: certState.Certificate, + myCert: crt, } // always start the counter from 2, as packet 1 and packet 2 are handshake packets. ci.messageCounter.Add(2) - return ci + return ci, nil } func (cs *ConnectionState) MarshalJSON() ([]byte, error) { @@ -89,3 +90,7 @@ func (cs *ConnectionState) MarshalJSON() ([]byte, error) { "message_counter": cs.messageCounter.Load(), }) } + +func (cs *ConnectionState) Curve() cert.Curve { + return cs.myCert.Curve() +} diff --git a/control.go b/control.go index 26159845c..5302e564a 100644 --- a/control.go +++ b/control.go @@ -19,9 +19,9 @@ import ( type controlEach func(h *HostInfo) type controlHostLister interface { - QueryVpnIp(vpnIp netip.Addr) *HostInfo + QueryVpnAddr(vpnAddr netip.Addr) *HostInfo ForEachIndex(each controlEach) - ForEachVpnIp(each controlEach) + ForEachVpnAddr(each controlEach) GetPreferredRanges() []netip.Prefix } @@ -37,7 +37,7 @@ type Control struct { } type ControlHostInfo struct { - VpnIp netip.Addr `json:"vpnIp"` + VpnAddrs []netip.Addr `json:"vpnAddrs"` LocalIndex uint32 `json:"localIndex"` RemoteIndex uint32 `json:"remoteIndex"` RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` @@ -131,10 +131,13 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { - if c.f.myVpnNet.Addr() == vpnIp { - return c.f.pki.GetCertState().Certificate.Copy() + _, found := c.f.myVpnAddrsTable.Lookup(vpnIp) + if found { + // Only returning the default certificate since its impossible + // for any other host but ourselves to have more than 1 + return c.f.pki.getCertState().GetDefaultCertificate().Copy() } - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } @@ -148,7 +151,7 @@ func (c *Control) CreateTunnel(vpnIp netip.Addr) { // PrintTunnel creates a new tunnel to the given vpn ip. func (c *Control) PrintTunnel(vpnIp netip.Addr) *ControlHostInfo { - hi := c.f.hostMap.QueryVpnIp(vpnIp) + hi := c.f.hostMap.QueryVpnAddr(vpnIp) if hi == nil { return nil } @@ -165,9 +168,9 @@ func (c *Control) QueryLighthouse(vpnIp netip.Addr) *CacheMap { return hi.CopyCache() } -// GetHostInfoByVpnIp returns a single tunnels hostInfo, or nil if not found +// GetHostInfoByVpnAddr returns a single tunnels hostInfo, or nil if not found // Caller should take care to Unmap() any 4in6 addresses prior to calling. -func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHostInfo { +func (c *Control) GetHostInfoByVpnAddr(vpnAddr netip.Addr, pending bool) *ControlHostInfo { var hl controlHostLister if pending { hl = c.f.handshakeManager @@ -175,7 +178,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos hl = c.f.hostMap } - h := hl.QueryVpnIp(vpnIp) + h := hl.QueryVpnAddr(vpnAddr) if h == nil { return nil } @@ -187,7 +190,7 @@ func (c *Control) GetHostInfoByVpnIp(vpnIp netip.Addr, pending bool) *ControlHos // SetRemoteForTunnel forces a tunnel to use a specific remote // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *ControlHostInfo { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return nil } @@ -200,7 +203,7 @@ func (c *Control) SetRemoteForTunnel(vpnIp netip.Addr, addr netip.AddrPort) *Con // CloseTunnel closes a fully established tunnel. If localOnly is false it will notify the remote end as well. // Caller should take care to Unmap() any 4in6 addresses prior to calling. func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { - hostInfo := c.f.hostMap.QueryVpnIp(vpnIp) + hostInfo := c.f.hostMap.QueryVpnAddr(vpnIp) if hostInfo == nil { return false } @@ -225,18 +228,14 @@ func (c *Control) CloseTunnel(vpnIp netip.Addr, localOnly bool) bool { // the int returned is a count of tunnels closed func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { //TODO: this is probably better as a function in ConnectionManager or HostMap directly - lighthouses := c.f.lightHouse.GetLighthouses() - shutdown := func(h *HostInfo) { - if excludeLighthouses { - if _, ok := lighthouses[h.vpnIp]; ok { - return - } + if excludeLighthouses && c.f.lightHouse.IsAnyLighthouseAddr(h.vpnAddrs) { + return } c.f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) c.f.closeTunnel(h) - c.l.WithField("vpnIp", h.vpnIp).WithField("udpAddr", h.remote). + c.l.WithField("vpnIp", h.vpnAddrs[0]).WithField("udpAddr", h.remote). Debug("Sending close tunnel message") closed++ } @@ -246,7 +245,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Relays map c.f.hostMap.Lock() for _, relayingHost := range c.f.hostMap.Relays { - relayingHosts[relayingHost.vpnIp] = relayingHost + relayingHosts[relayingHost.vpnAddrs[0]] = relayingHost } c.f.hostMap.Unlock() @@ -254,7 +253,7 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) { // Grab the hostMap lock to access the Hosts map c.f.hostMap.Lock() for _, relayHost := range c.f.hostMap.Indexes { - if _, ok := relayingHosts[relayHost.vpnIp]; !ok { + if _, ok := relayingHosts[relayHost.vpnAddrs[0]]; !ok { hostInfos = append(hostInfos, relayHost) } } @@ -274,9 +273,8 @@ func (c *Control) Device() overlay.Device { } func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { - chi := ControlHostInfo{ - VpnIp: h.vpnIp, + VpnAddrs: make([]netip.Addr, len(h.vpnAddrs)), LocalIndex: h.localIndexId, RemoteIndex: h.remoteIndexId, RemoteAddrs: h.remotes.CopyAddrs(preferredRanges), @@ -285,6 +283,10 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { CurrentRemote: h.remote, } + for i, a := range h.vpnAddrs { + chi.VpnAddrs[i] = a + } + if h.ConnectionState != nil { chi.MessageCounter = h.ConnectionState.messageCounter.Load() } @@ -299,7 +301,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { func listHostMapHosts(hl controlHostLister) []ControlHostInfo { hosts := make([]ControlHostInfo, 0) pr := hl.GetPreferredRanges() - hl.ForEachVpnIp(func(hostinfo *HostInfo) { + hl.ForEachVpnAddr(func(hostinfo *HostInfo) { hosts = append(hosts, copyHostInfo(hostinfo, pr)) }) return hosts diff --git a/control_test.go b/control_test.go index fdfc0a57e..0f21ed92a 100644 --- a/control_test.go +++ b/control_test.go @@ -19,7 +19,7 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l := test.NewLogger() // Special care must be taken to re-use all objects provided to the hostmap and certificate in the expectedInfo object // To properly ensure we are not exposing core memory to the caller - hm := newHostMap(l, netip.Prefix{}) + hm := newHostMap(l) hm.preferredRanges.Store(&[]netip.Prefix{}) remote1 := netip.MustParseAddrPort("0.0.0.100:4444") @@ -35,9 +35,9 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { Mask: net.IPMask{255, 255, 255, 0}, } - remotes := NewRemoteList(nil) - remotes.unlockedPrependV4(netip.IPv4Unspecified(), NewIp4AndPortFromNetIP(remote1.Addr(), remote1.Port())) - remotes.unlockedPrependV6(netip.IPv4Unspecified(), NewIp6AndPortFromNetIP(remote2.Addr(), remote2.Port())) + remotes := NewRemoteList([]netip.Addr{netip.IPv4Unspecified()}, nil) + remotes.unlockedPrependV4(netip.IPv4Unspecified(), netAddrToProtoV4AddrPort(remote1.Addr(), remote1.Port())) + remotes.unlockedPrependV6(netip.IPv4Unspecified(), netAddrToProtoV6AddrPort(remote2.Addr(), remote2.Port())) vpnIp, ok := netip.AddrFromSlice(ipNet.IP) assert.True(t, ok) @@ -51,11 +51,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnIp}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -70,11 +70,11 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { }, remoteIndexId: 200, localIndexId: 201, - vpnIp: vpnIp2, + vpnAddrs: []netip.Addr{vpnIp2}, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, }, &Interface{}) @@ -85,10 +85,10 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { l: logrus.New(), } - thi := c.GetHostInfoByVpnIp(vpnIp, false) + thi := c.GetHostInfoByVpnAddr(vpnIp, false) expectedInfo := ControlHostInfo{ - VpnIp: vpnIp, + VpnAddrs: []netip.Addr{vpnIp}, LocalIndex: 201, RemoteIndex: 200, RemoteAddrs: []netip.AddrPort{remote2, remote1}, @@ -100,13 +100,13 @@ func TestControl_GetHostInfoByVpnIp(t *testing.T) { } // Make sure we don't have any unexpected fields - assertFields(t, []string{"VpnIp", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) + assertFields(t, []string{"VpnAddrs", "LocalIndex", "RemoteIndex", "RemoteAddrs", "Cert", "MessageCounter", "CurrentRemote", "CurrentRelaysToMe", "CurrentRelaysThroughMe"}, thi) assert.EqualValues(t, &expectedInfo, thi) test.AssertDeepCopyEqual(t, &expectedInfo, thi) // Make sure we don't panic if the host info doesn't have a cert yet assert.NotPanics(t, func() { - thi = c.GetHostInfoByVpnIp(vpnIp2, false) + thi = c.GetHostInfoByVpnAddr(vpnIp2, false) }) } diff --git a/control_tester.go b/control_tester.go index fa87e5300..451dac533 100644 --- a/control_tester.go +++ b/control_tester.go @@ -6,8 +6,6 @@ package nebula import ( "net/netip" - "github.com/slackhq/nebula/cert" - "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/slackhq/nebula/header" @@ -51,15 +49,15 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType header.MessageType, // This is necessary if you did not configure static hosts or are not running a lighthouse func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() if toAddr.Addr().Is4() { - remoteList.unlockedPrependV4(vpnIp, NewIp4AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV4(vpnIp, netAddrToProtoV4AddrPort(toAddr.Addr(), toAddr.Port())) } else { - remoteList.unlockedPrependV6(vpnIp, NewIp6AndPortFromNetIP(toAddr.Addr(), toAddr.Port())) + remoteList.unlockedPrependV6(vpnIp, netAddrToProtoV6AddrPort(toAddr.Addr(), toAddr.Port())) } } @@ -67,12 +65,12 @@ func (c *Control) InjectLightHouseAddr(vpnIp netip.Addr, toAddr netip.AddrPort) // This is necessary to inform an initiator of possible relays for communicating with a responder func (c *Control) InjectRelays(vpnIp netip.Addr, relayVpnIps []netip.Addr) { c.f.lightHouse.Lock() - remoteList := c.f.lightHouse.unlockedGetRemoteList(vpnIp) + remoteList := c.f.lightHouse.unlockedGetRemoteList([]netip.Addr{vpnIp}) remoteList.Lock() defer remoteList.Unlock() c.f.lightHouse.Unlock() - remoteList.unlockedSetRelay(vpnIp, vpnIp, relayVpnIps) + remoteList.unlockedSetRelay(vpnIp, relayVpnIps) } // GetFromTun will pull a packet off the tun side of nebula @@ -99,21 +97,42 @@ func (c *Control) InjectUDPPacket(p *udp.Packet) { } // InjectTunUDPPacket puts a udp packet on the tun interface. Using UDP here because it's a simpler protocol -func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort uint16, data []byte) { - //TODO: IPV6-WORK - ip := layers.IPv4{ - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - SrcIP: c.f.inside.Cidr().Addr().Unmap().AsSlice(), - DstIP: toIp.Unmap().AsSlice(), +func (c *Control) InjectTunUDPPacket(toAddr netip.Addr, toPort uint16, fromAddr netip.Addr, fromPort uint16, data []byte) { + serialize := make([]gopacket.SerializableLayer, 0) + var netLayer gopacket.NetworkLayer + if toAddr.Is6() { + if !fromAddr.Is6() { + panic("Cant send ipv6 to ipv4") + } + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip + } else { + if !fromAddr.Is4() { + panic("Cant send ipv4 to ipv6") + } + + ip := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + SrcIP: fromAddr.Unmap().AsSlice(), + DstIP: toAddr.Unmap().AsSlice(), + } + serialize = append(serialize, ip) + netLayer = ip } udp := layers.UDP{ SrcPort: layers.UDPPort(fromPort), DstPort: layers.UDPPort(toPort), } - err := udp.SetNetworkLayerForChecksum(&ip) + err := udp.SetNetworkLayerForChecksum(netLayer) if err != nil { panic(err) } @@ -123,7 +142,9 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui ComputeChecksums: true, FixLengths: true, } - err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data)) + + serialize = append(serialize, &udp, gopacket.Payload(data)) + err = gopacket.SerializeLayers(buffer, opt, serialize...) if err != nil { panic(err) } @@ -131,8 +152,8 @@ func (c *Control) InjectTunUDPPacket(toIp netip.Addr, toPort uint16, fromPort ui c.f.inside.(*overlay.TestTun).Send(buffer.Bytes()) } -func (c *Control) GetVpnIp() netip.Addr { - return c.f.myVpnNet.Addr() +func (c *Control) GetVpnAddrs() []netip.Addr { + return c.f.myVpnAddrs } func (c *Control) GetUDPAddr() netip.AddrPort { @@ -140,7 +161,7 @@ func (c *Control) GetUDPAddr() netip.AddrPort { } func (c *Control) KillPendingTunnel(vpnIp netip.Addr) bool { - hostinfo := c.f.handshakeManager.QueryVpnIp(vpnIp) + hostinfo := c.f.handshakeManager.QueryVpnAddr(vpnIp) if hostinfo == nil { return false } @@ -153,8 +174,8 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } -func (c *Control) GetCert() cert.Certificate { - return c.f.pki.GetCertState().Certificate +func (c *Control) GetCertState() *CertState { + return c.f.pki.getCertState() } func (c *Control) ReHandshake(vpnIp netip.Addr) { diff --git a/dns_server.go b/dns_server.go index 750123122..710f6ed85 100644 --- a/dns_server.go +++ b/dns_server.go @@ -8,6 +8,7 @@ import ( "strings" "sync" + "github.com/gaissmai/bart" "github.com/miekg/dns" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -21,24 +22,39 @@ var dnsAddr string type dnsRecords struct { sync.RWMutex - dnsMap map[string]string - hostMap *HostMap + l *logrus.Logger + dnsMap4 map[string]netip.Addr + dnsMap6 map[string]netip.Addr + hostMap *HostMap + myVpnAddrsTable *bart.Table[struct{}] } -func newDnsRecords(hostMap *HostMap) *dnsRecords { +func newDnsRecords(l *logrus.Logger, cs *CertState, hostMap *HostMap) *dnsRecords { return &dnsRecords{ - dnsMap: make(map[string]string), - hostMap: hostMap, + l: l, + dnsMap4: make(map[string]netip.Addr), + dnsMap6: make(map[string]netip.Addr), + hostMap: hostMap, + myVpnAddrsTable: cs.myVpnAddrsTable, } } -func (d *dnsRecords) Query(data string) string { +func (d *dnsRecords) Query(q uint16, data string) netip.Addr { + data = strings.ToLower(data) d.RLock() defer d.RUnlock() - if r, ok := d.dnsMap[strings.ToLower(data)]; ok { - return r + switch q { + case dns.TypeA: + if r, ok := d.dnsMap4[data]; ok { + return r + } + case dns.TypeAAAA: + if r, ok := d.dnsMap6[data]; ok { + return r + } } - return "" + + return netip.Addr{} } func (d *dnsRecords) QueryCert(data string) string { @@ -47,7 +63,7 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } - hostinfo := d.hostMap.QueryVpnIp(ip) + hostinfo := d.hostMap.QueryVpnAddr(ip) if hostinfo == nil { return "" } @@ -64,38 +80,62 @@ func (d *dnsRecords) QueryCert(data string) string { return string(b) } -func (d *dnsRecords) Add(host, data string) { +// Add adds the first IPv4 and IPv6 address that appears in `addresses` as the record for `host` +func (d *dnsRecords) Add(host string, addresses []netip.Addr) { + host = strings.ToLower(host) d.Lock() defer d.Unlock() - d.dnsMap[strings.ToLower(host)] = data + haveV4 := false + haveV6 := false + for _, addr := range addresses { + if addr.Is4() && !haveV4 { + d.dnsMap4[host] = addr + haveV4 = true + } else if addr.Is6() && !haveV6 { + d.dnsMap6[host] = addr + haveV6 = true + } + if haveV4 && haveV6 { + break + } + } +} + +func (d *dnsRecords) isSelfNebulaOrLocalhost(addr string) bool { + a, _, _ := net.SplitHostPort(addr) + b, err := netip.ParseAddr(a) + if err != nil { + return false + } + + if b.IsLoopback() { + return true + } + + _, found := d.myVpnAddrsTable.Lookup(b) + return found //if we found it in this table, it's good } -func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { +func (d *dnsRecords) parseQuery(m *dns.Msg, w dns.ResponseWriter) { for _, q := range m.Question { switch q.Qtype { - case dns.TypeA: - l.Debugf("Query for A %s", q.Name) - ip := dnsR.Query(q.Name) - if ip != "" { - rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) + case dns.TypeA, dns.TypeAAAA: + qType := dns.TypeToString[q.Qtype] + d.l.Debugf("Query for %s %s", qType, q.Name) + ip := d.Query(q.Qtype, q.Name) + if ip.IsValid() { + rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", q.Name, qType, ip)) if err == nil { m.Answer = append(m.Answer, rr) } } case dns.TypeTXT: - a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) - b, err := netip.ParseAddr(a) - if err != nil { + // We only answer these queries from nebula nodes or localhost + if !d.isSelfNebulaOrLocalhost(w.RemoteAddr().String()) { return } - - // We don't answer these queries from non nebula nodes or localhost - //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) - if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { - return - } - l.Debugf("Query for TXT %s", q.Name) - ip := dnsR.QueryCert(q.Name) + d.l.Debugf("Query for TXT %s", q.Name) + ip := d.QueryCert(q.Name) if ip != "" { rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) if err == nil { @@ -110,26 +150,24 @@ func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { } } -func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { +func (d *dnsRecords) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) m.Compress = false switch r.Opcode { case dns.OpcodeQuery: - parseQuery(l, m, w) + d.parseQuery(m, w) } w.WriteMsg(m) } -func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { - dnsR = newDnsRecords(hostMap) +func dnsMain(l *logrus.Logger, cs *CertState, hostMap *HostMap, c *config.C) func() { + dnsR = newDnsRecords(l, cs, hostMap) // attach request handler func - dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { - handleDnsRequest(l, w, r) - }) + dns.HandleFunc(".", dnsR.handleDnsRequest) c.RegisterReloadCallback(func(c *config.C) { reloadDns(l, c) diff --git a/dns_server_test.go b/dns_server_test.go index 69f6ae84f..f4643a36a 100644 --- a/dns_server_test.go +++ b/dns_server_test.go @@ -1,23 +1,38 @@ package nebula import ( + "net/netip" "testing" "github.com/miekg/dns" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/stretchr/testify/assert" ) func TestParsequery(t *testing.T) { - //TODO: This test is basically pointless + l := logrus.New() hostMap := &HostMap{} - ds := newDnsRecords(hostMap) - ds.Add("test.com.com", "1.2.3.4") + ds := newDnsRecords(l, &CertState{}, hostMap) + addrs := []netip.Addr{ + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.5"), + netip.MustParseAddr("fd01::24"), + netip.MustParseAddr("fd01::25"), + } + ds.Add("test.com.com", addrs) - m := new(dns.Msg) + m := &dns.Msg{} m.SetQuestion("test.com.com", dns.TypeA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "1.2.3.4", m.Answer[0].(*dns.A).A.String()) - //parseQuery(m) + m = &dns.Msg{} + m.SetQuestion("test.com.com", dns.TypeAAAA) + ds.parseQuery(m, nil) + assert.NotNil(t, m.Answer) + assert.Equal(t, "fd01::24", m.Answer[0].(*dns.AAAA).AAAA.String()) } func Test_getDnsServerAddr(t *testing.T) { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index f6069bf46..aab5445ea 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -4,14 +4,17 @@ package e2e import ( - "fmt" "net/netip" "slices" "testing" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" + "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/e2e/router" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" @@ -20,12 +23,12 @@ import ( ) func BenchmarkHotPath(b *testing.B) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, _, _, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() @@ -35,7 +38,7 @@ func BenchmarkHotPath(b *testing.B) { r.CancelFlowLogs() for n := 0; n < b.N; n++ { - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) _ = r.RouteForAllUntilTxTun(theirControl) } @@ -44,19 +47,19 @@ func BenchmarkHotPath(b *testing.B) { } func TestGoodHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Start the servers myControl.Start() theirControl.Start() t.Log("Send a udp packet through to begin standing up the tunnel, this should come out the other side") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) t.Log("Have them consume my stage 0 packet. They have a tunnel now") theirControl.InjectUDPPacket(myControl.GetFromUDP(true)) @@ -77,16 +80,16 @@ func TestGoodHandshake(t *testing.T) { myControl.WaitForType(1, 0, theirControl) t.Log("Make sure our host infos are correct") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) t.Log("Get that cached packet and make sure it looks right") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Do a bidirectional tunnel test") r := router.NewR(t, myControl, theirControl) defer r.RenderFlow() - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Final hostmaps", myControl, theirControl) myControl.Stop() @@ -95,14 +98,14 @@ func TestGoodHandshake(t *testing.T) { } func TestWrongResponderHandshake(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) // Put the evil udp addr in for their vpn Ip, this is a case of being lied to by the lighthouse. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -114,7 +117,7 @@ func TestWrongResponderHandshake(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -131,8 +134,8 @@ func TestWrongResponderHandshake(t *testing.T) { }) t.Log("Evil tunnel is closed, inject the correct udp addr for them") - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) t.Log("Route until we see the cached packet") @@ -153,18 +156,18 @@ func TestWrongResponderHandshake(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") //TODO: assert hostmaps for everyone r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) @@ -174,19 +177,19 @@ func TestWrongResponderHandshake(t *testing.T) { } func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.99/24", nil) - evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(ca, caKey, "evil", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.99/24", nil) + evilControl, evilVpnIp, evilUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "evil", "10.128.0.2/24", nil) o := m{ "static_host_map": m{ - theirVpnIpNet.Addr().String(): []string{evilUdpAddr.String()}, + theirVpnIpNet[0].Addr().String(): []string{evilUdpAddr.String()}, }, } - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.100/24", o) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.100/24", o) // Put the evil udp addr in for their vpn addr, this is a case of a remote at a static entry changing its vpn addr. - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), evilUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), evilUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl, evilControl) @@ -198,7 +201,7 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { evilControl.Start() t.Log("Start the handshake process, we will route until we see the evil tunnel closed") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) h := &header.H{} r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { @@ -215,8 +218,8 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { }) t.Log("Evil tunnel is closed, inject the correct udp addr for them") - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - pendingHi := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), true) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + pendingHi := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), true) assert.NotContains(t, pendingHi.RemoteAddrs, evilUdpAddr) t.Log("Route until we see the cached packet") @@ -237,18 +240,19 @@ func TestWrongResponderHandshakeStaticHostMap(t *testing.T) { t.Log("My cached packet should be received by them") myCachedPacket := theirControl.GetFromTun(true) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) t.Log("Test the tunnel with them") - assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl) - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIpNet, theirVpnIpNet, myControl, theirControl) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Flush all packets from all controllers") r.FlushAll() t.Log("Ensure ensure I don't have any hostinfo artifacts from evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), true), "My pending hostmap should not contain evil") - assert.Nil(t, myControl.GetHostInfoByVpnIp(evilVpnIp.Addr(), false), "My main hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), true), "My pending hostmap should not contain evil") + assert.Nil(t, myControl.GetHostInfoByVpnAddr(evilVpnIp[0].Addr(), false), "My main hostmap should not contain evil") + //NOTE: if evil lost the handshake race it may still have a tunnel since me would reject the handshake since the tunnel is complete //TODO: assert hostmaps for everyone r.RenderHostmaps("Final hostmaps", myControl, theirControl, evilControl) @@ -261,13 +265,13 @@ func TestStage1Race(t *testing.T) { // This tests ensures that two hosts handshaking with each other at the same time will allow traffic to flow // But will eventually collapse down to a single tunnel - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -278,8 +282,8 @@ func TestStage1Race(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake to start on both me and them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1 handshake packets") myHsForThem := myControl.GetFromUDP(true) @@ -291,14 +295,14 @@ func TestStage1Race(t *testing.T) { r.Log("Route until they receive a message packet") myCachedPacket := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Their cached packet should be received by me") theirCachedPacket := r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them"), theirCachedPacket, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.Log("Do a bidirectional tunnel test") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myHostmapHosts := myControl.ListHostmapHosts(false) myHostmapIndexes := myControl.ListHostmapIndexes(false) @@ -316,7 +320,7 @@ func TestStage1Race(t *testing.T) { r.Log("Spin until connection manager tears down a tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } @@ -338,13 +342,13 @@ func TestStage1Race(t *testing.T) { } func TestUncleanShutdownRaceLoser(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -355,10 +359,10 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Nuke my hostmap") myHostmap := myControl.GetHostmap() @@ -366,17 +370,17 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { myHostmap.Indexes = map[uint32]*nebula.HostInfo{} myHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me again")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me again")) p = r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me again"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(theirControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(theirControl.GetHostmap().Indexes) < start { break } @@ -387,13 +391,13 @@ func TestUncleanShutdownRaceLoser(t *testing.T) { } func TestUncleanShutdownRaceWinner(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -404,10 +408,10 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirControl.Start() r.Log("Trigger a handshake from me to them") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, theirControl) r.Log("Nuke my hostmap") @@ -416,18 +420,18 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { theirHostmap.Indexes = map[uint32]*nebula.HostInfo{} theirHostmap.RemoteIndexes = map[uint32]*nebula.HostInfo{} - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them again")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them again")) p = r.RouteForAllUntilTxTun(myControl) - assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from them again"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Derp hostmaps", myControl, theirControl) r.Log("Assert the tunnel works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.Log("Wait for the dead index to go away") start := len(myControl.GetHostmap().Indexes) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) if len(myControl.GetHostmap().Indexes) < start { break } @@ -438,15 +442,15 @@ func TestUncleanShutdownRaceWinner(t *testing.T) { } func TestRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -458,31 +462,162 @@ func TestRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) //TODO: assert we actually used the relay even though it should be impossible for a tunnel to have occurred without it } +func TestReestablishRelays(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + + // Teach my how to get to the relay and that their can be reached via the relay + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, myControl, relayControl, theirControl) + defer r.RenderFlow() + + // Start the servers + myControl.Start() + relayControl.Start() + theirControl.Start() + + t.Log("Trigger a handshake from me to them via the relay") + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p := r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) + + t.Log("Ensure packet traversal from them to me via the relay") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + assertUdpPacket(t, []byte("Hi from them"), p, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), 80, 80) + + // If we break the relay's connection to 'them', 'me' needs to detect and recover the connection + r.Log("Close the tunnel") + relayControl.CloseTunnel(theirVpnIpNet[0].Addr(), true) + + start := len(myControl.GetHostmap().Indexes) + curIndexes := len(myControl.GetHostmap().Indexes) + for curIndexes >= start { + curIndexes = len(myControl.GetHostmap().Indexes) + r.Logf("Wait for the dead index to go away:start=%v indexes, currnet=%v indexes", start, curIndexes) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me should fail")) + + r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType { + return router.RouteAndExit + }) + time.Sleep(2 * time.Second) + } + r.Log("Dead index went away. Woot!") + r.RenderHostmaps("Me removed hostinfo", myControl, relayControl, theirControl) + // Next packet should re-establish a relayed connection and work just great. + + t.Logf("Assert the tunnel...") + for { + t.Log("RouteForAllUntilTxTun") + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + + p = r.RouteForAllUntilTxTun(theirControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.SrcIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.DstIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("DstIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from me" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + t.Log("Assert the tunnel works the other way, too") + for { + t.Log("RouteForAllUntilTxTun") + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) + + p = r.RouteForAllUntilTxTun(myControl) + r.Log("Assert the tunnel works") + packet := gopacket.NewPacket(p, layers.LayerTypeIPv4, gopacket.Lazy) + v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if slices.Compare(v4.DstIP, myVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("Dst is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + if slices.Compare(v4.SrcIP, theirVpnIpNet[0].Addr().AsSlice()) != 0 { + t.Logf("SrcIP is unexpected...this is not the packet I'm looking for. Keep looking") + continue + } + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udp == nil { + t.Log("Not a UDP packet. This is not the packet I'm looking for. Keep looking") + continue + } + data := packet.ApplicationLayer() + if data == nil { + t.Log("No data found in packet. This is not the packet I'm looking for. Keep looking.") + continue + } + if string(data.Payload()) != "Hi from them" { + t.Logf("Unexpected payload: '%v', keep looking", string(data.Payload())) + continue + } + t.Log("I found my lost packet. I am so happy.") + break + } + r.RenderHostmaps("Final hostmaps", myControl, relayControl, theirControl) + +} + func TestStage1RaceRelays(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -494,14 +629,14 @@ func TestStage1RaceRelays(t *testing.T) { theirControl.Start() r.Log("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) r.Log("Wait for a packet from them to me") p := r.RouteForAllUntilTxTun(myControl) @@ -518,21 +653,21 @@ func TestStage1RaceRelays(t *testing.T) { func TestStage1RaceRelays2(t *testing.T) { //NOTE: this is a race between me and relay resulting in a full tunnel from me to them via relay - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) l := NewTestLogger() // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - theirControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + theirControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - theirControl.InjectRelays(myVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + theirControl.InjectRelays(myVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - relayControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + relayControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -545,16 +680,16 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Get a tunnel between me and relay") l.Info("Get a tunnel between me and relay") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) r.Log("Get a tunnel between them and relay") l.Info("Get a tunnel between them and relay") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) r.Log("Trigger a handshake from both them and me via relay to them and me") l.Info("Trigger a handshake from both them and me via relay to them and me") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) //r.RouteUntilAfterMsgType(myControl, header.Control, header.MessageNone) //r.RouteUntilAfterMsgType(theirControl, header.Control, header.MessageNone) @@ -567,7 +702,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) t.Log("Wait until we remove extra tunnels") l.Info("Wait until we remove extra tunnels") @@ -587,7 +722,7 @@ func TestStage1RaceRelays2(t *testing.T) { "theirControl": len(theirControl.GetHostmap().Indexes), "relayControl": len(relayControl.GetHostmap().Indexes), }).Info("Waiting for hostinfos to be removed...") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) retries-- @@ -595,7 +730,7 @@ func TestStage1RaceRelays2(t *testing.T) { r.Log("Assert the tunnel works") l.Info("Assert the tunnel works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) myControl.Stop() theirControl.Stop() @@ -606,15 +741,15 @@ func TestStage1RaceRelays2(t *testing.T) { } func TestRehandshakingRelays(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.1/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.128/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -626,17 +761,17 @@ func TestRehandshakingRelays(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -654,8 +789,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -667,8 +802,8 @@ func TestRehandshakingRelays(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -679,13 +814,13 @@ func TestRehandshakingRelays(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -693,7 +828,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -701,7 +836,7 @@ func TestRehandshakingRelays(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -710,15 +845,15 @@ func TestRehandshakingRelays(t *testing.T) { func TestRehandshakingRelaysPrimary(t *testing.T) { // This test is the same as TestRehandshakingRelays but one of the terminal types is a primary swap winner - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, _, _ := newSimpleServer(ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) - relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.128/24", m{"relay": m{"use_relays": true}}) + relayControl, relayVpnIpNet, relayUdpAddr, relayConfig := newSimpleServer(cert.Version1, ca, caKey, "relay ", "10.128.0.1/24", m{"relay": m{"am_relay": true}}) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them ", "10.128.0.2/24", m{"relay": m{"use_relays": true}}) // Teach my how to get to the relay and that their can be reached via the relay - myControl.InjectLightHouseAddr(relayVpnIpNet.Addr(), relayUdpAddr) - myControl.InjectRelays(theirVpnIpNet.Addr(), []netip.Addr{relayVpnIpNet.Addr()}) - relayControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) + myControl.InjectLightHouseAddr(relayVpnIpNet[0].Addr(), relayUdpAddr) + myControl.InjectRelays(theirVpnIpNet[0].Addr(), []netip.Addr{relayVpnIpNet[0].Addr()}) + relayControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, relayControl, theirControl) @@ -730,17 +865,17 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { theirControl.Start() t.Log("Trigger a handshake from me to them via the relay") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) p := r.RouteForAllUntilTxTun(theirControl) r.Log("Assert the tunnel works") - assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), 80, 80) + assertUdpPacket(t, []byte("Hi from me"), p, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), 80, 80) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -758,8 +893,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") - assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) - c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), myControl, relayControl, r) + c := myControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") @@ -771,8 +906,8 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") - assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) - c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) + assertTunnel(t, theirVpnIpNet[0].Addr(), relayVpnIpNet[0].Addr(), theirControl, relayControl, r) + c := theirControl.GetHostInfoByVpnAddr(relayVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") @@ -783,13 +918,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.RenderHostmaps("working hostmaps", myControl, relayControl, theirControl) // We should have two hostinfos on all sides for len(myControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for myControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(myControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -797,7 +932,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(theirControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for theirControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(theirControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -805,7 +940,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { for len(relayControl.GetHostmap().Indexes) != 2 { t.Logf("Waiting for relayControl hostinfos (%v != 2) to get cleaned up from lack of use...", len(relayControl.GetHostmap().Indexes)) r.Log("Assert the relay tunnel still works") - assertTunnel(t, theirVpnIpNet.Addr(), myVpnIpNet.Addr(), theirControl, myControl, r) + assertTunnel(t, theirVpnIpNet[0].Addr(), myVpnIpNet[0].Addr(), theirControl, myControl, r) r.Log("yupitdoes") time.Sleep(time.Second) } @@ -813,13 +948,13 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { } func TestRehandshaking(t *testing.T) { - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -830,12 +965,12 @@ func TestRehandshaking(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -852,8 +987,8 @@ func TestRehandshaking(t *testing.T) { myConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) if len(c.Cert.Groups()) != 0 { // We have a new certificate now break @@ -880,19 +1015,19 @@ func TestRehandshaking(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) + c := theirControl.GetHostInfoByVpnAddr(myVpnIpNet[0].Addr(), false) assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides @@ -910,13 +1045,13 @@ func TestRehandshaking(t *testing.T) { func TestRehandshakingLoser(t *testing.T) { // The purpose of this test is that the race loser renews their certificate and rehandshakes. The final tunnel // Should be the one with the new certificate - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(ca, caKey, "me ", "10.128.0.2/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(ca, caKey, "them", "10.128.0.1/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, myConfig := newSimpleServer(cert.Version1, ca, caKey, "me ", "10.128.0.2/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, theirConfig := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.1/24", nil) // Put their info in our lighthouse and vice versa - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Build a router so we don't have to reason who gets which packet r := router.NewR(t, myControl, theirControl) @@ -927,16 +1062,12 @@ func TestRehandshakingLoser(t *testing.T) { theirControl.Start() t.Log("Stand up a tunnel between me and them") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - - tt1 := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - tt2 := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - fmt.Println(tt1.LocalIndex, tt2.LocalIndex) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := cert_test.NewTestCert(cert.Version1, cert.Curve_CURVE25519, ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) caB, err := ca.MarshalPEM() if err != nil { @@ -953,8 +1084,8 @@ func TestRehandshakingLoser(t *testing.T) { theirConfig.ReloadConfigString(string(rc)) for { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break @@ -980,19 +1111,19 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Spin until there is only 1 tunnel") for len(myControl.GetHostmap().Indexes)+len(theirControl.GetHostmap().Indexes) > 2 { - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) t.Log("Connection manager hasn't ticked yet") time.Sleep(time.Second) } - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) myFinalHostmapHosts := myControl.ListHostmapHosts(false) myFinalHostmapIndexes := myControl.ListHostmapIndexes(false) theirFinalHostmapHosts := theirControl.ListHostmapHosts(false) theirFinalHostmapIndexes := theirControl.ListHostmapIndexes(false) // Make sure the correct tunnel won - theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) + theirCertInMe := myControl.GetHostInfoByVpnAddr(theirVpnIpNet[0].Addr(), false) assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides @@ -1010,13 +1141,13 @@ func TestRaceRegression(t *testing.T) { // This test forces stage 1, stage 2, stage 1 to be received by me from them // We had a bug where we were not finding the duplicate handshake and responding to the final stage 1 which // caused a cross-linked hostinfo - ca, _, caKey, _ := NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(ca, caKey, "me", "10.128.0.1/24", nil) - theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(ca, caKey, "them", "10.128.0.2/24", nil) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version1, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + myControl, myVpnIpNet, myUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "me", "10.128.0.1/24", nil) + theirControl, theirVpnIpNet, theirUdpAddr, _ := newSimpleServer(cert.Version1, ca, caKey, "them", "10.128.0.2/24", nil) // Put their info in our lighthouse - myControl.InjectLightHouseAddr(theirVpnIpNet.Addr(), theirUdpAddr) - theirControl.InjectLightHouseAddr(myVpnIpNet.Addr(), myUdpAddr) + myControl.InjectLightHouseAddr(theirVpnIpNet[0].Addr(), theirUdpAddr) + theirControl.InjectLightHouseAddr(myVpnIpNet[0].Addr(), myUdpAddr) // Start the servers myControl.Start() @@ -1030,8 +1161,8 @@ func TestRaceRegression(t *testing.T) { //them rx stage:2 initiatorIndex=120607833 responderIndex=4209862089 t.Log("Start both handshakes") - myControl.InjectTunUDPPacket(theirVpnIpNet.Addr(), 80, 80, []byte("Hi from me")) - theirControl.InjectTunUDPPacket(myVpnIpNet.Addr(), 80, 80, []byte("Hi from them")) + myControl.InjectTunUDPPacket(theirVpnIpNet[0].Addr(), 80, myVpnIpNet[0].Addr(), 80, []byte("Hi from me")) + theirControl.InjectTunUDPPacket(myVpnIpNet[0].Addr(), 80, theirVpnIpNet[0].Addr(), 80, []byte("Hi from them")) t.Log("Get both stage 1") myStage1ForThem := myControl.GetFromUDP(true) @@ -1061,8 +1192,48 @@ func TestRaceRegression(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) t.Log("Make sure the tunnel still works") - assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) + assertTunnel(t, myVpnIpNet[0].Addr(), theirVpnIpNet[0].Addr(), myControl, theirControl, r) + + myControl.Stop() + theirControl.Stop() +} + +func TestV2NonPrimaryWithLighthouse(t *testing.T) { + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + lhControl, lhVpnIpNet, lhUdpAddr, _ := newSimpleServer(cert.Version2, ca, caKey, "lh ", "10.128.0.1/24, ff::1/64", m{"lighthouse": m{"am_lighthouse": true}}) + + o := m{ + "static_host_map": m{ + lhVpnIpNet[1].Addr().String(): []string{lhUdpAddr.String()}, + }, + "lighthouse": m{ + "hosts": []string{lhVpnIpNet[1].Addr().String()}, + "local_allow_list": m{ + // Try and block our lighthouse updates from using the actual addresses assigned to this computer + // If we start discovering addresses the test router doesn't know about then test traffic cant flow + "10.0.0.0/24": true, + "::/0": false, + }, + }, + } + myControl, myVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "me ", "10.128.0.2/24, ff::2/64", o) + theirControl, theirVpnIpNet, _, _ := newSimpleServer(cert.Version2, ca, caKey, "them", "10.128.0.3/24, ff::3/64", o) + + // Build a router so we don't have to reason who gets which packet + r := router.NewR(t, lhControl, myControl, theirControl) + defer r.RenderFlow() + + // Start the servers + lhControl.Start() + myControl.Start() + theirControl.Start() + + t.Log("Stand up an ipv6 tunnel between me and them") + assert.True(t, myVpnIpNet[1].Addr().Is6()) + assert.True(t, theirVpnIpNet[1].Addr().Is6()) + assertTunnel(t, myVpnIpNet[1].Addr(), theirVpnIpNet[1].Addr(), myControl, theirControl, r) + lhControl.Stop() myControl.Stop() theirControl.Stop() } diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 77996f3da..f5df2fb25 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -8,6 +8,7 @@ import ( "io" "net/netip" "os" + "strings" "testing" "time" @@ -17,6 +18,7 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e/router" "github.com/stretchr/testify/assert" @@ -26,25 +28,35 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { +func newSimpleServer(v cert.Version, caCrt cert.Certificate, caKey []byte, name string, sVpnNetworks string, overrides m) (*nebula.Control, []netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() - vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) - if err != nil { - panic(err) + var vpnNetworks []netip.Prefix + for _, sn := range strings.Split(sVpnNetworks, ",") { + vpnIpNet, err := netip.ParsePrefix(strings.TrimSpace(sn)) + if err != nil { + panic(err) + } + vpnNetworks = append(vpnNetworks, vpnIpNet) + } + + if len(vpnNetworks) == 0 { + panic("no vpn networks") } var udpAddr netip.AddrPort - if vpnIpNet.Addr().Is4() { - budpIp := vpnIpNet.Addr().As4() + if vpnNetworks[0].Addr().Is4() { + budpIp := vpnNetworks[0].Addr().As4() budpIp[1] -= 128 udpAddr = netip.AddrPortFrom(netip.AddrFrom4(budpIp), 4242) } else { - budpIp := vpnIpNet.Addr().As16() - budpIp[13] -= 128 + budpIp := vpnNetworks[0].Addr().As16() + // beef for funsies + budpIp[2] = 190 + budpIp[3] = 239 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) + _, _, myPrivKey, myPEM := cert_test.NewTestCert(v, cert.Curve_CURVE25519, caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnNetworks, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { @@ -88,11 +100,16 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe } if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + final := m{} + err = mergo.Merge(&final, overrides, mergo.WithAppendSlice) if err != nil { panic(err) } - mc = overrides + err = mergo.Merge(&final, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = final } cb, err := yaml.Marshal(mc) @@ -109,7 +126,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe panic(err) } - return control, vpnIpNet, udpAddr, c + return control, vpnNetworks, udpAddr, c } type doneCb func() @@ -132,27 +149,28 @@ func deadline(t *testing.T, seconds time.Duration) doneCb { func assertTunnel(t *testing.T, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control, r *router.R) { // Send a packet from them to me - controlB.InjectTunUDPPacket(vpnIpA, 80, 90, []byte("Hi from B")) + controlB.InjectTunUDPPacket(vpnIpA, 80, vpnIpB, 90, []byte("Hi from B")) bPacket := r.RouteForAllUntilTxTun(controlA) assertUdpPacket(t, []byte("Hi from B"), bPacket, vpnIpB, vpnIpA, 90, 80) // And once more from me to them - controlA.InjectTunUDPPacket(vpnIpB, 80, 90, []byte("Hello from A")) + controlA.InjectTunUDPPacket(vpnIpB, 80, vpnIpA, 90, []byte("Hello from A")) aPacket := r.RouteForAllUntilTxTun(controlB) assertUdpPacket(t, []byte("Hello from A"), aPacket, vpnIpA, vpnIpB, 90, 80) } -func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIpB netip.Addr, controlA, controlB *nebula.Control) { +func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnNetsA, vpnNetsB []netip.Prefix, controlA, controlB *nebula.Control) { // Get both host infos - hBinA := controlA.GetHostInfoByVpnIp(vpnIpB, false) - assert.NotNil(t, hBinA, "Host B was not found by vpnIp in controlA") + //TODO: we may want to loop over each vpnAddr and assert all the things + hBinA := controlA.GetHostInfoByVpnAddr(vpnNetsB[0].Addr(), false) + assert.NotNil(t, hBinA, "Host B was not found by vpnAddr in controlA") - hAinB := controlB.GetHostInfoByVpnIp(vpnIpA, false) - assert.NotNil(t, hAinB, "Host A was not found by vpnIp in controlB") + hAinB := controlB.GetHostInfoByVpnAddr(vpnNetsA[0].Addr(), false) + assert.NotNil(t, hAinB, "Host A was not found by vpnAddr in controlB") // Check that both vpn and real addr are correct - assert.Equal(t, vpnIpB, hBinA.VpnIp, "Host B VpnIp is wrong in control A") - assert.Equal(t, vpnIpA, hAinB.VpnIp, "Host A VpnIp is wrong in control B") + assert.EqualValues(t, getAddrs(vpnNetsB), hBinA.VpnAddrs, "Host B VpnIp is wrong in control A") + assert.EqualValues(t, getAddrs(vpnNetsA), hAinB.VpnAddrs, "Host A VpnIp is wrong in control B") assert.Equal(t, addrB, hBinA.CurrentRemote, "Host B remote is wrong in control A") assert.Equal(t, addrA, hAinB.CurrentRemote, "Host A remote is wrong in control B") @@ -179,6 +197,33 @@ func assertHostInfoPair(t *testing.T, addrA, addrB netip.AddrPort, vpnIpA, vpnIp } func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + if toIp.Is6() { + assertUdpPacket6(t, expected, b, fromIp, toIp, fromPort, toPort) + } else { + assertUdpPacket4(t, expected, b, fromIp, toIp, fromPort, toPort) + } +} + +func assertUdpPacket6(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { + packet := gopacket.NewPacket(b, layers.LayerTypeIPv6, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + assert.NotNil(t, v6, "No ipv6 data found") + + assert.Equal(t, fromIp.AsSlice(), []byte(v6.SrcIP), "Source ip was incorrect") + assert.Equal(t, toIp.AsSlice(), []byte(v6.DstIP), "Dest ip was incorrect") + + udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + assert.NotNil(t, udp, "No udp data found") + + assert.Equal(t, fromPort, uint16(udp.SrcPort), "Source port was incorrect") + assert.Equal(t, toPort, uint16(udp.DstPort), "Dest port was incorrect") + + data := packet.ApplicationLayer() + assert.NotNil(t, data) + assert.Equal(t, expected, data.Payload(), "Data was incorrect") +} + +func assertUdpPacket4(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, fromPort, toPort uint16) { packet := gopacket.NewPacket(b, layers.LayerTypeIPv4, gopacket.Lazy) v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) assert.NotNil(t, v4, "No ipv4 data found") @@ -197,6 +242,14 @@ func assertUdpPacket(t *testing.T, expected, b []byte, fromIp, toIp netip.Addr, assert.Equal(t, expected, data.Payload(), "Data was incorrect") } +func getAddrs(ns []netip.Prefix) []netip.Addr { + var a []netip.Addr + for _, n := range ns { + a = append(a, n.Addr()) + } + return a +} + func NewTestLogger() *logrus.Logger { l := logrus.New() diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index 29fa95991..f2805d0c8 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -58,8 +58,9 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge - clusterName := strings.Trim(c.GetCert().Name(), " ") - clusterVpnIp := c.GetCert().Networks()[0].Addr() + crt := c.GetCertState().GetDefaultCertificate() + clusterName := strings.Trim(crt.Name(), " ") + clusterVpnIp := crt.Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() @@ -101,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { for _, idx := range indexes { hi, ok := hm.Indexes[idx] if ok { - r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) + r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnAddrs()) remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi diff --git a/e2e/router/router.go b/e2e/router/router.go index 08905705c..5e52ed77c 100644 --- a/e2e/router/router.go +++ b/e2e/router/router.go @@ -10,8 +10,8 @@ import ( "os" "path/filepath" "reflect" + "regexp" "sort" - "strings" "sync" "testing" "time" @@ -136,7 +136,10 @@ func NewR(t testing.TB, controls ...*nebula.Control) *R { panic("Duplicate listen address: " + addr.String()) } - r.vpnControls[c.GetVpnIp()] = c + for _, vpnAddr := range c.GetVpnAddrs() { + r.vpnControls[vpnAddr] = c + } + r.controls[addr] = c } @@ -213,11 +216,11 @@ func (r *R) renderFlow() { continue } participants[addr] = struct{}{} - sanAddr := strings.Replace(addr.String(), ":", "-", 1) + sanAddr := normalizeName(addr.String()) participantsVals = append(participantsVals, sanAddr) fmt.Fprintf( f, " participant %s as Nebula: %s
UDP: %s\n", - sanAddr, e.packet.from.GetVpnIp(), sanAddr, + sanAddr, e.packet.from.GetVpnAddrs(), sanAddr, ) } @@ -250,9 +253,9 @@ func (r *R) renderFlow() { fmt.Fprintf(f, " %s%s%s: %s(%s), index %v, counter: %v\n", - strings.Replace(p.from.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.from.GetUDPAddr().String()), line, - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), + normalizeName(p.to.GetUDPAddr().String()), h.TypeName(), h.SubTypeName(), h.RemoteIndex, h.MessageCounter, ) } @@ -267,6 +270,11 @@ func (r *R) renderFlow() { } } +func normalizeName(s string) string { + rx := regexp.MustCompile("[\\[\\]\\:]") + return rx.ReplaceAllLiteralString(s, "_") +} + // IgnoreFlow tells the router to stop recording future flows that matches the provided criteria. // messageType and subType will target nebula underlay packets while tun will target nebula overlay packets // NOTE: This is a very broad system, if you set tun to true then no more tun traffic will be rendered @@ -303,7 +311,7 @@ func (r *R) RenderHostmaps(title string, controls ...*nebula.Control) { func (r *R) renderHostmaps(title string) { c := maps.Values(r.controls) sort.SliceStable(c, func(i, j int) bool { - return c[i].GetVpnIp().Compare(c[j].GetVpnIp()) > 0 + return c[i].GetVpnAddrs()[0].Compare(c[j].GetVpnAddrs()[0]) > 0 }) s := renderHostmaps(c...) @@ -419,10 +427,11 @@ func (r *R) RouteUntilTxTun(sender *nebula.Control, receiver *nebula.Control) [] // Nope, lets push the sender along case p := <-udpTx: r.Lock() - c := r.getControl(sender.GetUDPAddr(), p.To, p) + a := sender.GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic("No control for udp tx " + a.String()) } fp := r.unlockedInjectFlow(sender, c, p, false) c.InjectUDPPacket(p) @@ -475,10 +484,11 @@ func (r *R) RouteForAllUntilTxTun(receiver *nebula.Control) []byte { } else { // we are a udp tx, route and continue p := rx.Interface().(*udp.Packet) - c := r.getControl(cm[x].GetUDPAddr(), p.To, p) + a := cm[x].GetUDPAddr() + c := r.getControl(a, p.To, p) if c == nil { r.Unlock() - panic("No control for udp tx") + panic(fmt.Sprintf("No control for udp tx %s", p.To)) } fp := r.unlockedInjectFlow(cm[x], c, p, false) c.InjectUDPPacket(p) @@ -711,30 +721,42 @@ func (r *R) getControl(fromAddr, toAddr netip.AddrPort, p *udp.Packet) *nebula.C } func (r *R) formatUdpPacket(p *packet) string { - packet := gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) - v4 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) - if v4 == nil { - panic("not an ipv4 packet") + var packet gopacket.Packet + var srcAddr netip.Addr + + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv6, gopacket.Lazy) + if packet.ErrorLayer() == nil { + v6 := packet.Layer(layers.LayerTypeIPv6).(*layers.IPv6) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) + } else { + packet = gopacket.NewPacket(p.packet.Data, layers.LayerTypeIPv4, gopacket.Lazy) + v6 := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + if v6 == nil { + panic("not an ipv6 packet") + } + srcAddr, _ = netip.AddrFromSlice(v6.SrcIP) } from := "unknown" - srcAddr, _ := netip.AddrFromSlice(v4.SrcIP) if c, ok := r.vpnControls[srcAddr]; ok { from = c.GetUDPAddr().String() } - udp := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) - if udp == nil { + udpLayer := packet.Layer(layers.LayerTypeUDP).(*layers.UDP) + if udpLayer == nil { panic("not a udp packet") } data := packet.ApplicationLayer() return fmt.Sprintf( " %s-->>%s: src port: %v
dest port: %v
data: \"%v\"\n", - strings.Replace(from, ":", "-", 1), - strings.Replace(p.to.GetUDPAddr().String(), ":", "-", 1), - udp.SrcPort, - udp.DstPort, + normalizeName(from), + normalizeName(p.to.GetUDPAddr().String()), + udpLayer.SrcPort, + udpLayer.DstPort, string(data.Payload()), ) } diff --git a/examples/config.yml b/examples/config.yml index c74ffc68f..1a312838b 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -13,6 +13,12 @@ pki: # disconnect_invalid is a toggle to force a client to be disconnected if the certificate is expired or invalid. #disconnect_invalid: true + # default_version controls which certificate version is used in handshakes. + # This setting only applies if both a v1 and a v2 certificate are configured, in which case it will default to `1`. + # Once all hosts in the mesh are configured with both a v1 and v2 certificate then this should be changed to `2`. + # After all hosts in the mesh are using a v2 certificate then v1 certificates are no longer needed. + # default_version: 1 + # The static host map defines a set of hosts with fixed IP addresses on the internet (or any network). # A host can have multiple fixed IP addresses defined here, and nebula will try each when establishing a tunnel. # The syntax is: @@ -336,10 +342,13 @@ firewall: # host: `any` or a literal hostname, ie `test-host` # group: `any` or a literal group name, ie `default-group` # groups: Same as group but accepts a list of values. Multiple values are AND'd together and a certificate would have to contain all groups to pass - # cidr: a remote CIDR, `0.0.0.0/0` is any. - # local_cidr: a local CIDR, `0.0.0.0/0` is any. This could be used to filter destinations when using unsafe_routes. - # Default is `any` unless the certificate contains subnets and then the default is the ip issued in the certificate - # if `default_local_cidr_any` is false, otherwise its `any`. + # cidr: a remote CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. //TODO: we have a problem, firewall needs to understand this and should probably allow `any` for both + # local_cidr: a local CIDR, `0.0.0.0/0` is any ipv4 and `::/0` is any ipv6. This could be used to filter destinations when using unsafe_routes. + # //TODO: probably should have an `any` that covers both ip versions + # If no unsafe networks are present in the certificate(s) or `default_local_cidr_any` is true then the default is any ipv4 or ipv6 network. + # Otherwise the default is any vpn network assigned to via the certificate. + # `default_local_cidr_any` defaults to false and is deprecated, it will be removed in a future release. + # If there are unsafe routes present its best to set `local_cidr` to whatever best fits the situation. # ca_name: An issuing CA name # ca_sha: An issuing CA shasum diff --git a/firewall.go b/firewall.go index 80a828057..0aae7d6cf 100644 --- a/firewall.go +++ b/firewall.go @@ -8,6 +8,7 @@ import ( "hash/fnv" "net/netip" "reflect" + "slices" "strconv" "strings" "sync" @@ -22,7 +23,7 @@ import ( ) type FirewallInterface interface { - AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, ip, localIp netip.Prefix, caName string, caSha string) error + AddRule(incoming bool, proto uint8, startPort int32, endPort int32, groups []string, host string, addr, localAddr netip.Prefix, caName string, caSha string) error } type conn struct { @@ -51,9 +52,12 @@ type Firewall struct { UDPTimeout time.Duration //linux: 180s max DefaultTimeout time.Duration //linux: 600s - // Used to ensure we don't emit local packets for ips we don't own - localIps *bart.Table[struct{}] - assignedCIDR netip.Prefix + // routableNetworks describes the vpn addresses as well as any unsafe networks issued to us in the certificate. + // The vpn addresses are a full bit match while the unsafe networks only match the prefix + routableNetworks *bart.Table[struct{}] + + // assignedNetworks is a list of vpn networks assigned to us in the certificate. + assignedNetworks []netip.Prefix hasUnsafeNetworks bool rules string @@ -67,9 +71,9 @@ type Firewall struct { } type firewallMetrics struct { - droppedLocalIP metrics.Counter - droppedRemoteIP metrics.Counter - droppedNoRule metrics.Counter + droppedLocalAddr metrics.Counter + droppedRemoteAddr metrics.Counter + droppedNoRule metrics.Counter } type FirewallConntrack struct { @@ -126,84 +130,87 @@ type firewallLocalCIDR struct { } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. +// The certificate provided should be the highest version loaded in memory. func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration - var min, max time.Duration + var tmin, tmax time.Duration if tcpTimeout < UDPTimeout { - min = tcpTimeout - max = UDPTimeout + tmin = tcpTimeout + tmax = UDPTimeout } else { - min = UDPTimeout - max = tcpTimeout + tmin = UDPTimeout + tmax = tcpTimeout } - if defaultTimeout < min { - min = defaultTimeout - } else if defaultTimeout > max { - max = defaultTimeout + if defaultTimeout < tmin { + tmin = defaultTimeout + } else if defaultTimeout > tmax { + tmax = defaultTimeout } - localIps := new(bart.Table[struct{}]) - var assignedCIDR netip.Prefix - var assignedSet bool + routableNetworks := new(bart.Table[struct{}]) + var assignedNetworks []netip.Prefix for _, network := range c.Networks() { nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) - localIps.Insert(nprefix, struct{}{}) - - if !assignedSet { - // Only grabbing the first one in the cert since any more than that currently has undefined behavior - assignedCIDR = nprefix - assignedSet = true - } + routableNetworks.Insert(nprefix, struct{}{}) + assignedNetworks = append(assignedNetworks, network) } hasUnsafeNetworks := false for _, n := range c.UnsafeNetworks() { - localIps.Insert(n, struct{}{}) + routableNetworks.Insert(n, struct{}{}) hasUnsafeNetworks = true } return &Firewall{ Conntrack: &FirewallConntrack{ Conns: make(map[firewall.Packet]*conn), - TimerWheel: NewTimerWheel[firewall.Packet](min, max), + TimerWheel: NewTimerWheel[firewall.Packet](tmin, tmax), }, InRules: newFirewallTable(), OutRules: newFirewallTable(), TCPTimeout: tcpTimeout, UDPTimeout: UDPTimeout, DefaultTimeout: defaultTimeout, - localIps: localIps, - assignedCIDR: assignedCIDR, + routableNetworks: routableNetworks, + assignedNetworks: assignedNetworks, hasUnsafeNetworks: hasUnsafeNetworks, l: l, incomingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.incoming.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.incoming.dropped.no_rule", nil), }, outgoingMetrics: firewallMetrics{ - droppedLocalIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_ip", nil), - droppedRemoteIP: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_ip", nil), - droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), + droppedLocalAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.local_addr", nil), + droppedRemoteAddr: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.remote_addr", nil), + droppedNoRule: metrics.GetOrRegisterCounter("firewall.outgoing.dropped.no_rule", nil), }, } } -func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, cs *CertState, c *config.C) (*Firewall, error) { + certificate := cs.getCertificate(cert.Version2) + if certificate == nil { + certificate = cs.getCertificate(cert.Version1) + } + + if certificate == nil { + panic("No certificate available to reconfigure the firewall") + } + fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), c.GetDuration("firewall.conntrack.udp_timeout", time.Minute*3), c.GetDuration("firewall.conntrack.default_timeout", time.Minute*10), - nc, + certificate, //TODO: max_connections ) - //TODO: Flip to false after v1.9 release - fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", true) + fw.defaultLocalCIDRAny = c.GetBool("firewall.default_local_cidr_any", false) inboundAction := c.GetString("firewall.inbound_action", "drop") switch inboundAction { @@ -283,7 +290,7 @@ func (f *Firewall) AddRule(incoming bool, proto uint8, startPort int32, endPort fp = ft.TCP case firewall.ProtoUDP: fp = ft.UDP - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: fp = ft.ICMP case firewall.ProtoAny: fp = ft.AnyProto @@ -424,26 +431,25 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure remote address matches nebula certificate - if remoteCidr := h.remoteCidr; remoteCidr != nil { - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := remoteCidr.Lookup(fp.RemoteIP) + if h.networks != nil { + _, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { - f.metrics(incoming).droppedRemoteIP.Inc(1) + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } else { // Simple case: Certificate has one IP and no subnets - if fp.RemoteIP != h.vpnIp { - f.metrics(incoming).droppedRemoteIP.Inc(1) + //TODO: we can make this more performant + if !slices.Contains(h.vpnAddrs, fp.RemoteAddr) { + f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP } } // Make sure we are supposed to be handling this local ip address - //TODO: this would be better if we had a least specific match lookup, could waste time here, need to benchmark since the algo is different - _, ok := f.localIps.Lookup(fp.LocalIP) + _, ok := f.routableNetworks.Lookup(fp.LocalAddr) if !ok { - f.metrics(incoming).droppedLocalIP.Inc(1) + f.metrics(incoming).droppedLocalAddr.Inc(1) return ErrInvalidLocalIP } @@ -629,7 +635,7 @@ func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedC if ft.UDP.match(p, incoming, c, caPool) { return true } - case firewall.ProtoICMP: + case firewall.ProtoICMP, firewall.ProtoICMPv6: if ft.ICMP.match(p, incoming, c, caPool) { return true } @@ -859,9 +865,9 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool } matched := false - prefix := netip.PrefixFrom(p.RemoteIP, p.RemoteIP.BitLen()) + prefix := netip.PrefixFrom(p.RemoteAddr, p.RemoteAddr.BitLen()) fr.CIDR.EachLookupPrefix(prefix, func(prefix netip.Prefix, val *firewallLocalCIDR) bool { - if prefix.Contains(p.RemoteIP) && val.match(p, c) { + if prefix.Contains(p.RemoteAddr) && val.match(p, c) { matched = true return false } @@ -877,9 +883,14 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } - localIp = f.assignedCIDR + for _, network := range f.assignedNetworks { + flc.LocalCIDR.Insert(network, struct{}{}) + } + return nil + } else if localIp.Bits() == 0 { flc.Any = true + return nil } flc.LocalCIDR.Insert(localIp, struct{}{}) @@ -895,7 +906,7 @@ func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate return true } - _, ok := flc.LocalCIDR.Lookup(p.LocalIP) + _, ok := flc.LocalCIDR.Lookup(p.LocalAddr) return ok } diff --git a/firewall/packet.go b/firewall/packet.go index 8954f4c47..1d8f12a0c 100644 --- a/firewall/packet.go +++ b/firewall/packet.go @@ -9,18 +9,19 @@ import ( type m map[string]interface{} const ( - ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever - ProtoTCP = 6 - ProtoUDP = 17 - ProtoICMP = 1 + ProtoAny = 0 // When we want to handle HOPOPT (0) we can change this, if ever + ProtoTCP = 6 + ProtoUDP = 17 + ProtoICMP = 1 + ProtoICMPv6 = 58 PortAny = 0 // Special value for matching `port: any` PortFragment = -1 // Special value for matching `port: fragment` ) type Packet struct { - LocalIP netip.Addr - RemoteIP netip.Addr + LocalAddr netip.Addr + RemoteAddr netip.Addr LocalPort uint16 RemotePort uint16 Protocol uint8 @@ -29,8 +30,8 @@ type Packet struct { func (fp *Packet) Copy() *Packet { return &Packet{ - LocalIP: fp.LocalIP, - RemoteIP: fp.RemoteIP, + LocalAddr: fp.LocalAddr, + RemoteAddr: fp.RemoteAddr, LocalPort: fp.LocalPort, RemotePort: fp.RemotePort, Protocol: fp.Protocol, @@ -51,8 +52,8 @@ func (fp Packet) MarshalJSON() ([]byte, error) { proto = fmt.Sprintf("unknown %v", fp.Protocol) } return json.Marshal(m{ - "LocalIP": fp.LocalIP.String(), - "RemoteIP": fp.RemoteIP.String(), + "LocalAddr": fp.LocalAddr.String(), + "RemoteAddr": fp.RemoteAddr.String(), "LocalPort": fp.LocalPort, "RemotePort": fp.RemotePort, "Protocol": proto, diff --git a/firewall_test.go b/firewall_test.go index 57cd32ae5..c093a6e7b 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewFirewall(t *testing.T) { @@ -128,8 +129,8 @@ func TestFirewall_Drop(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -149,9 +150,9 @@ func TestFirewall_Drop(t *testing.T) { InvertedGroups: map[string]struct{}{"default-group": {}}, }, }, - vpnIp: netip.MustParseAddr("1.2.3.4"), + vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -166,10 +167,10 @@ func TestFirewall_Drop(t *testing.T) { assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch - oldRemote := p.RemoteIP - p.RemoteIP = netip.MustParseAddr("1.2.3.10") + oldRemote := p.RemoteAddr + p.RemoteAddr = netip.MustParseAddr("1.2.3.10") assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrInvalidRemoteIP) - p.RemoteIP = oldRemote + p.RemoteAddr = oldRemote // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) @@ -235,7 +236,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { } ip := netip.MustParsePrefix("9.254.254.254/32") for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: ip.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: ip.Addr()}, true, c, cp)) } }) @@ -261,7 +262,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { InvertedGroups: map[string]struct{}{"nope": {}}, } for n := 0; n < b.N; n++ { - assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.False(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) @@ -285,7 +286,7 @@ func BenchmarkFirewallTable_match(b *testing.B) { InvertedGroups: map[string]struct{}{"good-group": {}}, } for n := 0; n < b.N; n++ { - assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalIP: pfix.Addr()}, true, c, cp)) + assert.True(b, ft.match(firewall.Packet{Protocol: firewall.ProtoTCP, LocalPort: 100, LocalAddr: pfix.Addr()}, true, c, cp)) } }) @@ -308,8 +309,8 @@ func TestFirewall_Drop2(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -329,9 +330,9 @@ func TestFirewall_Drop2(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(c.Certificate) + h.buildNetworks(c.Certificate) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -345,7 +346,7 @@ func TestFirewall_Drop2(t *testing.T) { peerCert: &c1, }, } - h1.CreateRemoteCIDR(c1.Certificate) + h1.buildNetworks(c1.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -364,8 +365,8 @@ func TestFirewall_Drop3(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 1, RemotePort: 1, Protocol: firewall.ProtoUDP, @@ -391,9 +392,9 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c1, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h1.CreateRemoteCIDR(c1.Certificate) + h1.buildNetworks(c1.Certificate) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -406,9 +407,9 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c2, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h2.CreateRemoteCIDR(c2.Certificate) + h2.buildNetworks(c2.Certificate) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -421,9 +422,9 @@ func TestFirewall_Drop3(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c3, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h3.CreateRemoteCIDR(c3.Certificate) + h3.buildNetworks(c3.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -446,8 +447,8 @@ func TestFirewall_DropConntrackReload(t *testing.T) { l.SetOutput(ob) p := firewall.Packet{ - LocalIP: netip.MustParseAddr("1.2.3.4"), - RemoteIP: netip.MustParseAddr("1.2.3.4"), + LocalAddr: netip.MustParseAddr("1.2.3.4"), + RemoteAddr: netip.MustParseAddr("1.2.3.4"), LocalPort: 10, RemotePort: 90, Protocol: firewall.ProtoUDP, @@ -468,9 +469,9 @@ func TestFirewall_DropConntrackReload(t *testing.T) { ConnectionState: &ConnectionState{ peerCert: &c, }, - vpnIp: network.Addr(), + vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(c.Certificate) + h.buildNetworks(c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -622,55 +623,58 @@ func TestNewFirewallFromConfig(t *testing.T) { l := test.NewLogger() // Test a bad rule definition c := &dummyCert{} + cs, err := newCertState(cert.Version2, nil, c, false, cert.Curve_CURVE25519, nil) + require.NoError(t, err) + conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} - _, err := NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} - _, err = NewFirewallFromConfig(l, c, conf) + _, err = NewFirewallFromConfig(l, cs, conf) assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } diff --git a/go.mod b/go.mod index f46499099..2ff9976ce 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,6 @@ require ( github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/sirupsen/logrus v1.9.3 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e - github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 github.com/stretchr/testify v1.9.0 github.com/vishvananda/netlink v1.3.0 diff --git a/go.sum b/go.sum index dacc3d37e..d0e9c5520 100644 --- a/go.sum +++ b/go.sum @@ -137,8 +137,6 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= -github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6 h1:pnnLyeX7o/5aX8qUQ69P/mLojDqwda8hFOCBTmP/6hw= github.com/stefanberger/go-pkcs11uri v0.0.0-20230803200340-78284954bff6/go.mod h1:39R/xuhNgVhi+K0/zst4TLrJrVmbm6LVgl4A0+ZFS5M= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/handshake_ix.go b/handshake_ix.go index 3add83d88..43c794601 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -2,10 +2,12 @@ package nebula import ( "net/netip" + "slices" "time" "github.com/flynn/noise" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" ) @@ -16,30 +18,60 @@ import ( func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { err := f.handshakeManager.allocateIndex(hh) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to generate index") return false } - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, true, noise.HandshakeIX, []byte{}, 0) - hh.hostinfo.ConnectionState = ci + // If we're connecting to a v6 address we must use a v2 cert + cs := f.pki.getCertState() + v := cs.defaultVersion + for _, a := range hh.hostinfo.vpnAddrs { + if a.Is6() { + v = cert.Version2 + break + } + } - hsProto := &NebulaHandshakeDetails{ - InitiatorIndex: hh.hostinfo.localIndexId, - Time: uint64(time.Now().UnixNano()), - Cert: certState.RawCertificateNoKey, + crt := cs.getCertificate(v) + if crt == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate is available") + return false } - hsBytes := []byte{} + crtHs := cs.getHandshakeBytes(v) + if crtHs == nil { + f.l.WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Unable to handshake with host because no certificate handshake bytes is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, true, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", v). + Error("Failed to create connection state") + return false + } + hh.hostinfo.ConnectionState = ci hs := &NebulaHandshake{ - Details: hsProto, + Details: &NebulaHandshakeDetails{ + InitiatorIndex: hh.hostinfo.localIndexId, + Time: uint64(time.Now().UnixNano()), + Cert: crtHs, + CertVersion: uint32(v), + }, } - hsBytes, err = hs.Marshal() + hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs).WithField("certVersion", v). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to marshal handshake message") return false } @@ -48,7 +80,7 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { msg, _, _, err := ci.H.WriteMessage(h, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hh.hostinfo.vpnIp). + f.l.WithError(err).WithField("vpnAddrs", hh.hostinfo.vpnAddrs). WithField("handshake", m{"stage": 0, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return false } @@ -63,30 +95,44 @@ func ixHandshakeStage0(f *Interface, hh *HandshakeHostInfo) bool { } func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet []byte, h *header.H) { - certState := f.pki.GetCertState() - ci := NewConnectionState(f.l, f.cipher, certState, false, noise.HandshakeIX, []byte{}, 0) + cs := f.pki.getCertState() + crt := cs.GetDefaultCertificate() + if crt == nil { + f.l.WithField("udpAddr", addr). + WithField("handshake", m{"stage": 0, "style": "ix_psk0"}). + WithField("certVersion", cs.defaultVersion). + Error("Unable to handshake with host because no certificate is available") + } + + ci, err := NewConnectionState(f.l, cs, crt, false, noise.HandshakeIX) + if err != nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to create connection state") + return + } + // Mark packet 1 as seen so it doesn't show up as missed ci.window.Update(f.l, 1) msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed to call noise.ReadMessage") return } hs := &NebulaHandshake{} err = hs.Unmarshal(msg) - /* - l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex) - */ if err != nil || hs.Details == nil { f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Error("Failed unmarshal handshake message") return } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -99,6 +145,20 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } + if remoteCert.Certificate.Version() != ci.myCert.Version() { + // We started off using the wrong certificate version, lets see if we can match the version that was sent to us + rc := cs.getCertificate(remoteCert.Certificate.Version()) + if rc == nil { + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).WithField("cert", remoteCert). + Info("Unable to handshake with host due to missing certificate version") + return + } + + // Record the certificate we are actually using + ci.myCert = rc + } + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -111,30 +171,36 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + var vpnAddrs []netip.Addr certName := remoteCert.Certificate.Name() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() - if vpnIp == f.myVpnNet.Addr() { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). - WithField("certName", certName). - WithField("fingerprint", fingerprint). - WithField("issuer", issuer). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") - return - } - - if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + for _, network := range remoteCert.Certificate.Networks() { + vpnAddr := network.Addr() + _, found := f.myVpnAddrsTable.Lookup(vpnAddr) + if found { + f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Refusing to handshake with myself") return } + + if addr.IsValid() { + if !f.lightHouse.GetRemoteAllowList().Allow(vpnAddr, addr.Addr()) { + f.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + return + } + } + + vpnAddrs = append(vpnAddrs, vpnAddr) } myIndex, err := generateIndex(f.l) if err != nil { - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -146,17 +212,17 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ConnectionState: ci, localIndexId: myIndex, remoteIndexId: hs.Details.InitiatorIndex, - vpnIp: vpnIp, + vpnAddrs: vpnAddrs, HandshakePacket: make(map[uint8][]byte, 0), lastHandshakeTime: hs.Details.Time, relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -165,13 +231,26 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake message received") hs.Details.ResponderIndex = myIndex - hs.Details.Cert = certState.RawCertificateNoKey + hs.Details.Cert = cs.getHandshakeBytes(ci.myCert.Version()) + if hs.Details.Cert == nil { + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). + WithField("certName", certName). + WithField("fingerprint", fingerprint). + WithField("issuer", issuer). + WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). + WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVersion", ci.myCert.Version()). + Error("Unable to handshake with host because no certificate handshake bytes is available") + return + } + + hs.Details.CertVersion = uint32(ci.myCert.Version()) // Update the time in case their clock is way off from ours hs.Details.Time = uint64(time.Now().UnixNano()) hsBytes, err := hs.Marshal() if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -182,14 +261,14 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet nh := header.Encode(make([]byte, header.Len), header.Version, header.Handshake, header.HandshakeIXPSK0, hs.Details.InitiatorIndex, 2) msg, dKey, eKey, err := ci.H.WriteMessage(nh, hsBytes) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.WriteMessage") return } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -213,9 +292,9 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet ci.dKey = NewNebulaCipherState(dKey) ci.eKey = NewNebulaCipherState(eKey) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.CreateRemoteCIDR(remoteCert.Certificate) + hostinfo.buildNetworks(remoteCert.Certificate) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -225,7 +304,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if existing.SetRemoteIfPreferred(f.hostMap, addr) { // Send a test packet to ensure the other side has also switched to // the preferred remote - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) } msg = existing.HandshakePacket[2] @@ -233,11 +312,11 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if addr.IsValid() { err := f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). WithError(err).Error("Failed to send handshake message") } else { - f.l.WithField("vpnIp", existing.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") } @@ -247,16 +326,16 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", existing.vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", existing.vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("cached", true). Info("Handshake message sent") return } case ErrExistingHostInfo: // This means there was an existing tunnel and this handshake was older than the one we are currently based on - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("oldHandshakeTime", existing.lastHandshakeTime). WithField("newHandshakeTime", hostinfo.lastHandshakeTime). @@ -267,23 +346,23 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet Info("Handshake too old") // Send a test packet to trigger an authenticated tunnel test, this should suss out any lingering tunnel issues - f.SendMessageToVpnIp(header.Test, header.TestRequest, vpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + f.SendMessageToVpnAddr(header.Test, header.TestRequest, vpnAddrs[0], []byte(""), make([]byte, 12, 12), make([]byte, mtu)) return case ErrLocalIndexCollision: // This means we failed to insert because of collision on localIndexId. Just let the next handshake packet retry - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). WithField("initiatorIndex", hs.Details.InitiatorIndex).WithField("responderIndex", hs.Details.ResponderIndex). WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). - WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnIp). + WithField("localIndex", hostinfo.localIndexId).WithField("collision", existing.vpnAddrs). Error("Failed to add HostInfo due to localIndex collision") return default: // Shouldn't happen, but just in case someone adds a new error type to CheckAndComplete // And we forget to update it here - f.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -299,7 +378,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet if addr.IsValid() { err = f.outside.WriteTo(msg, addr) if err != nil { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -307,7 +386,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet WithField("remoteIndex", h.RemoteIndex).WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). WithError(err).Error("Failed to send handshake") } else { - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -320,9 +399,12 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet f.l.Error("Handshake send failed: both addr and via are nil.") return } - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + // I successfully received a handshake. Just in case I marked this tunnel as 'Disestablished', ensure + // it's correctly marked as working. + via.relayHI.relayState.UpdateRelayForByIdxState(via.remoteIdx, Established) f.SendVia(via.relayHI, via.relay, msg, make([]byte, 12), make([]byte, mtu), false) - f.l.WithField("vpnIp", vpnIp).WithField("relay", via.relayHI.vpnIp). + f.l.WithField("vpnAddrs", vpnAddrs).WithField("relay", via.relayHI.vpnAddrs[0]). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -349,8 +431,9 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hostinfo := hh.hostinfo if addr.IsValid() { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.Addr()) { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") + //TODO: this is kind of nonsense now + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], addr.Addr()) { + f.l.WithField("vpnIp", hostinfo.vpnAddrs).WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake") return false } } @@ -358,7 +441,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci := hostinfo.ConnectionState msg, eKey, dKey, err := ci.H.ReadMessage(nil, packet[header.Len:]) if err != nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).WithField("header", h). Error("Failed to call noise.ReadMessage") @@ -367,7 +450,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha // near future return false } else if dKey == nil || eKey == nil { - f.l.WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Error("Noise did not arrive at a key") @@ -379,16 +462,16 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha hs := &NebulaHandshake{} err = hs.Unmarshal(msg) if err != nil || hs.Details == nil { - f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}).Error("Failed unmarshal handshake message") // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again return true } - remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.pki.GetCAPool()) + remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) if err != nil { - e := f.l.WithError(err).WithField("vpnIp", hostinfo.vpnIp).WithField("udpAddr", addr). + e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) if f.l.Level > logrus.DebugLevel { @@ -413,7 +496,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + vpnNetworks := remoteCert.Certificate.Networks() certName := remoteCert.Certificate.Name() fingerprint := remoteCert.Fingerprint issuer := remoteCert.Certificate.Issuer() @@ -430,12 +513,17 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha if addr.IsValid() { hostinfo.SetRemote(addr) } else { - hostinfo.relayState.InsertRelayTo(via.relayHI.vpnIp) + hostinfo.relayState.InsertRelayTo(via.relayHI.vpnAddrs[0]) + } + + vpnAddrs := make([]netip.Addr, len(vpnNetworks)) + for i, n := range vpnNetworks { + vpnAddrs[i] = n.Addr() } // Ensure the right host responded - if vpnIp != hostinfo.vpnIp { - f.l.WithField("intendedVpnIp", hostinfo.vpnIp).WithField("haveVpnIp", vpnIp). + if !slices.Contains(vpnAddrs, hostinfo.vpnAddrs[0]) { + f.l.WithField("intendedVpnAddrs", hostinfo.vpnAddrs).WithField("haveVpnNetworks", vpnNetworks). WithField("udpAddr", addr).WithField("certName", certName). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). Info("Incorrect host responded to handshake") @@ -444,14 +532,14 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha f.handshakeManager.DeleteHostInfo(hostinfo) // Create a new hostinfo/handshake for the intended vpn ip - f.handshakeManager.StartHandshake(hostinfo.vpnIp, func(newHH *HandshakeHostInfo) { + f.handshakeManager.StartHandshake(hostinfo.vpnAddrs[0], func(newHH *HandshakeHostInfo) { //TODO: this doesnt know if its being added or is being used for caching a packet // Block the current used address newHH.hostinfo.remotes = hostinfo.remotes newHH.hostinfo.remotes.BlockRemote(addr) f.l.WithField("blockedUdpAddrs", newHH.hostinfo.remotes.CopyBlockedRemotes()). - WithField("vpnIp", newHH.hostinfo.vpnIp). + WithField("vpnNetworks", vpnNetworks). WithField("remotes", newHH.hostinfo.remotes.CopyAddrs(f.hostMap.GetPreferredRanges())). Info("Blocked addresses for handshakes") @@ -459,11 +547,8 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha newHH.packetStore = hh.packetStore hh.packetStore = []*cachedPacket{} - // Get the correct remote list for the host we did handshake with - hostinfo.SetRemote(addr) - hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) - // Finally, put the correct vpn ip in the host info, tell them to close the tunnel, and return true to tear down - hostinfo.vpnIp = vpnIp + // Finally, put the correct vpn addrs in the host info, tell them to close the tunnel, and return true to tear down + hostinfo.vpnAddrs = vpnAddrs f.sendCloseTunnel(hostinfo) }) @@ -474,7 +559,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha ci.window.Update(f.l, 2) duration := time.Since(hh.startTime).Nanoseconds() - f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). + f.l.WithField("vpnNetworks", vpnNetworks).WithField("udpAddr", addr). WithField("certName", certName). WithField("fingerprint", fingerprint). WithField("issuer", issuer). @@ -485,7 +570,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha Info("Handshake message received") // Build up the radix for the firewall if we have subnets in the cert - hostinfo.CreateRemoteCIDR(remoteCert.Certificate) + hostinfo.buildNetworks(remoteCert.Certificate) // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) diff --git a/handshake_manager.go b/handshake_manager.go index 48348939b..0e5c4fae7 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -13,6 +13,7 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" ) @@ -118,18 +119,18 @@ func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *Lig } } -func (c *HandshakeManager) Run(ctx context.Context) { - clockSource := time.NewTicker(c.config.tryInterval) +func (hm *HandshakeManager) Run(ctx context.Context) { + clockSource := time.NewTicker(hm.config.tryInterval) defer clockSource.Stop() for { select { case <-ctx.Done(): return - case vpnIP := <-c.trigger: - c.handleOutbound(vpnIP, true) + case vpnIP := <-hm.trigger: + hm.handleOutbound(vpnIP, true) case now := <-clockSource.C: - c.NextOutboundHandshakeTimerTick(now) + hm.NextOutboundHandshakeTimerTick(now) } } } @@ -159,14 +160,14 @@ func (hm *HandshakeManager) HandleIncoming(addr netip.AddrPort, via *ViaSender, } } -func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { - c.OutboundHandshakeTimer.Advance(now) +func (hm *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) { + hm.OutboundHandshakeTimer.Advance(now) for { - vpnIp, has := c.OutboundHandshakeTimer.Purge() + vpnIp, has := hm.OutboundHandshakeTimer.Purge() if !has { break } - c.handleOutbound(vpnIp, false) + hm.handleOutbound(vpnIp, false) } } @@ -208,7 +209,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered // NB ^ This comment doesn't jive. It's how the thing gets initialized. // It's the common path. Should it update every time, in case a future LH query/queries give us more info? if hostinfo.remotes == nil { - hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp) + hostinfo.remotes = hm.lightHouse.QueryCache([]netip.Addr{vpnIp}) } remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges()) @@ -267,59 +268,26 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts") // Send a RelayRequest to all known Relay IP's for _, relay := range hostinfo.remotes.relays { - // Don't relay to myself, and don't relay through the host I'm trying to connect to - if relay == vpnIp || relay == hm.lightHouse.myVpnNet.Addr() { + // Don't relay to myself + if relay == vpnIp { continue } - relayHostInfo := hm.mainHostMap.QueryVpnIp(relay) + + // Don't relay through the host I'm trying to connect to + _, found := hm.f.myVpnAddrsTable.Lookup(relay) + if found { + continue + } + + relayHostInfo := hm.mainHostMap.QueryVpnAddr(relay) if relayHostInfo == nil || !relayHostInfo.remote.IsValid() { hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target") hm.f.Handshake(relay) continue } - // Check the relay HostInfo to see if we already established a relay through it - if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok { - switch existingRelay.State { - case Established: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") - hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) - case Requested: - hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") - - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - - // Re-send the CreateRelay request, in case the previous one was lost. - m := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: existingRelay.LocalIndex, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), - } - msg, err := m.Marshal() - if err != nil { - hostinfo.logger(hm.l). - WithError(err). - Error("Failed to marshal Control message to create relay") - } else { - // This must send over the hostinfo, not over hm.Hosts[ip] - hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) - hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), - "relayTo": vpnIp, - "initiatorRelayIndex": existingRelay.LocalIndex, - "relay": relay}). - Info("send CreateRelayRequest") - } - default: - hostinfo.logger(hm.l). - WithField("vpnIp", vpnIp). - WithField("state", existingRelay.State). - WithField("relay", relayHostInfo.vpnIp). - Errorf("Relay unexpected state") - } - } else { + // Check the relay HostInfo to see if we already established a relay through + existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp) + if !ok { // No relays exist or requested yet. if relayHostInfo.remote.IsValid() { idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested) @@ -327,16 +295,35 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap") } - //TODO: IPV6-WORK - myVpnIpB := hm.f.myVpnNet.Addr().As4() - theirVpnIpB := vpnIp.As4() - m := NebulaControl{ Type: NebulaControl_CreateRelayRequest, InitiatorRelayIndex: idx, - RelayFromIp: binary.BigEndian.Uint32(myVpnIpB[:]), - RelayToIp: binary.BigEndian.Uint32(theirVpnIpB[:]), } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() if err != nil { hostinfo.logger(hm.l). @@ -345,13 +332,80 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered } else { hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) hm.l.WithFields(logrus.Fields{ - "relayFrom": hm.f.myVpnNet.Addr(), + "relayFrom": hm.f.myVpnAddrs[0], "relayTo": vpnIp, "initiatorRelayIndex": idx, "relay": relay}). Info("send CreateRelayRequest") } } + continue + } + + switch existingRelay.State { + case Established: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay") + hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false) + case Disestablished: + // Mark this relay as 'requested' + relayHostInfo.relayState.UpdateRelayForByIpState(vpnIp, Requested) + fallthrough + case Requested: + hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request") + // Re-send the CreateRelay request, in case the previous one was lost. + m := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: existingRelay.LocalIndex, + } + + switch relayHostInfo.GetCert().Certificate.Version() { + case cert.Version1: + if !hm.f.myVpnAddrs[0].Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 network because the relay is not running a current nebula version") + continue + } + + if !vpnIp.Is4() { + hostinfo.logger(hm.l).Error("can not establish v1 relay with a v6 remote network because the relay is not running a current nebula version") + continue + } + + b := hm.f.myVpnAddrs[0].As4() + m.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = vpnIp.As4() + m.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + case cert.Version2: + m.RelayFromAddr = netAddrToProtoAddr(hm.f.myVpnAddrs[0]) + m.RelayToAddr = netAddrToProtoAddr(vpnIp) + default: + hostinfo.logger(hm.l).Error("Unknown certificate version found while creating relay") + continue + } + msg, err := m.Marshal() + if err != nil { + hostinfo.logger(hm.l). + WithError(err). + Error("Failed to marshal Control message to create relay") + } else { + // This must send over the hostinfo, not over hm.Hosts[ip] + hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu)) + hm.l.WithFields(logrus.Fields{ + "relayFrom": hm.f.myVpnAddrs[0], + "relayTo": vpnIp, + "initiatorRelayIndex": existingRelay.LocalIndex, + "relay": relay}). + Info("send CreateRelayRequest") + } + case PeerRequested: + // PeerRequested only occurs in Forwarding relays, not Terminal relays, and this is a Terminal relay case. + fallthrough + default: + hostinfo.logger(hm.l). + WithField("vpnIp", vpnIp). + WithField("state", existingRelay.State). + WithField("relay", relay). + Errorf("Relay unexpected state") + } } } @@ -381,10 +435,10 @@ func (hm *HandshakeManager) GetOrHandshake(vpnIp netip.Addr, cacheCb func(*Hands } // StartHandshake will ensure a handshake is currently being attempted for the provided vpn ip -func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { +func (hm *HandshakeManager) StartHandshake(vpnAddr netip.Addr, cacheCb func(*HandshakeHostInfo)) *HostInfo { hm.Lock() - if hh, ok := hm.vpnIps[vpnIp]; ok { + if hh, ok := hm.vpnIps[vpnAddr]; ok { // We are already trying to handshake with this vpn ip if cacheCb != nil { cacheCb(hh) @@ -394,12 +448,12 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands } hostinfo := &HostInfo{ - vpnIp: vpnIp, + vpnAddrs: []netip.Addr{vpnAddr}, HandshakePacket: make(map[uint8][]byte, 0), relayState: RelayState{ - relays: map[netip.Addr]struct{}{}, - relayForByIp: map[netip.Addr]*Relay{}, - relayForByIdx: map[uint32]*Relay{}, + relays: map[netip.Addr]struct{}{}, + relayForByAddr: map[netip.Addr]*Relay{}, + relayForByIdx: map[uint32]*Relay{}, }, } @@ -407,9 +461,9 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands hostinfo: hostinfo, startTime: time.Now(), } - hm.vpnIps[vpnIp] = hh + hm.vpnIps[vpnAddr] = hh hm.metricInitiated.Inc(1) - hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval) + hm.OutboundHandshakeTimer.Add(vpnAddr, hm.config.tryInterval) if cacheCb != nil { cacheCb(hh) @@ -417,21 +471,21 @@ func (hm *HandshakeManager) StartHandshake(vpnIp netip.Addr, cacheCb func(*Hands // If this is a static host, we don't need to wait for the HostQueryReply // We can trigger the handshake right now - _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnIp] + _, doTrigger := hm.lightHouse.GetStaticHostList()[vpnAddr] if !doTrigger { // Add any calculated remotes, and trigger early handshake if one found - doTrigger = hm.lightHouse.addCalculatedRemotes(vpnIp) + doTrigger = hm.lightHouse.addCalculatedRemotes(vpnAddr) } if doTrigger { select { - case hm.trigger <- vpnIp: + case hm.trigger <- vpnAddr: default: } } hm.Unlock() - hm.lightHouse.QueryServer(vpnIp) + hm.lightHouse.QueryServer(vpnAddr) return hostinfo } @@ -452,14 +506,14 @@ var ( // // ErrLocalIndexCollision if we already have an entry in the main or pending // hostmap for the hostinfo.localIndexId. -func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { - c.mainHostMap.Lock() - defer c.mainHostMap.Unlock() - c.Lock() - defer c.Unlock() +func (hm *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) { + hm.mainHostMap.Lock() + defer hm.mainHostMap.Unlock() + hm.Lock() + defer hm.Unlock() // Check if we already have a tunnel with this vpn ip - existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp] + existingHostInfo, found := hm.mainHostMap.Hosts[hostinfo.vpnAddrs[0]] if found && existingHostInfo != nil { testHostInfo := existingHostInfo for testHostInfo != nil { @@ -476,31 +530,31 @@ func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket return existingHostInfo, ErrExistingHostInfo } - existingHostInfo.logger(c.l).Info("Taking new handshake") + existingHostInfo.logger(hm.l).Info("Taking new handshake") } - existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId] + existingIndex, found := hm.mainHostMap.Indexes[hostinfo.localIndexId] if found { // We have a collision, but for a different hostinfo return existingIndex, ErrLocalIndexCollision } - existingPendingIndex, found := c.indexes[hostinfo.localIndexId] + existingPendingIndex, found := hm.indexes[hostinfo.localIndexId] if found && existingPendingIndex.hostinfo != hostinfo { // We have a collision, but for a different hostinfo return existingPendingIndex.hostinfo, ErrLocalIndexCollision } - existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] - if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp { + existingRemoteIndex, found := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId] + if found && existingRemoteIndex != nil && existingRemoteIndex.vpnAddrs[0] != hostinfo.vpnAddrs[0] { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. - hostinfo.logger(c.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + hostinfo.logger(hm.l). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } - c.mainHostMap.unlockedAddHostInfo(hostinfo, f) + hm.mainHostMap.unlockedAddHostInfo(hostinfo, f) return existingHostInfo, nil } @@ -518,7 +572,7 @@ func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) { // We have a collision, but this can happen since we can't control // the remote ID. Just log about the situation as a note. hostinfo.logger(hm.l). - WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp). + WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnAddrs). Info("New host shadows existing host remoteIndex") } @@ -555,31 +609,34 @@ func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error { return errors.New("failed to generate unique localIndexId") } -func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { - c.Lock() - defer c.Unlock() - c.unlockedDeleteHostInfo(hostinfo) +func (hm *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) { + hm.Lock() + defer hm.Unlock() + hm.unlockedDeleteHostInfo(hostinfo) } -func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { - delete(c.vpnIps, hostinfo.vpnIp) - if len(c.vpnIps) == 0 { - c.vpnIps = map[netip.Addr]*HandshakeHostInfo{} +func (hm *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) { + for _, addr := range hostinfo.vpnAddrs { + delete(hm.vpnIps, addr) } - delete(c.indexes, hostinfo.localIndexId) - if len(c.vpnIps) == 0 { - c.indexes = map[uint32]*HandshakeHostInfo{} + if len(hm.vpnIps) == 0 { + hm.vpnIps = map[netip.Addr]*HandshakeHostInfo{} } - if c.l.Level >= logrus.DebugLevel { - c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + delete(hm.indexes, hostinfo.localIndexId) + if len(hm.indexes) == 0 { + hm.indexes = map[uint32]*HandshakeHostInfo{} + } + + if hm.l.Level >= logrus.DebugLevel { + hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.vpnIps), + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Pending hostmap hostInfo deleted") } } -func (hm *HandshakeManager) QueryVpnIp(vpnIp netip.Addr) *HostInfo { +func (hm *HandshakeManager) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { hh := hm.queryVpnIp(vpnIp) if hh != nil { return hh.hostinfo @@ -608,37 +665,37 @@ func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo { return hm.indexes[index] } -func (c *HandshakeManager) GetPreferredRanges() []netip.Prefix { - return c.mainHostMap.GetPreferredRanges() +func (hm *HandshakeManager) GetPreferredRanges() []netip.Prefix { + return hm.mainHostMap.GetPreferredRanges() } -func (c *HandshakeManager) ForEachVpnIp(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachVpnAddr(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.vpnIps { + for _, v := range hm.vpnIps { f(v.hostinfo) } } -func (c *HandshakeManager) ForEachIndex(f controlEach) { - c.RLock() - defer c.RUnlock() +func (hm *HandshakeManager) ForEachIndex(f controlEach) { + hm.RLock() + defer hm.RUnlock() - for _, v := range c.indexes { + for _, v := range hm.indexes { f(v.hostinfo) } } -func (c *HandshakeManager) EmitStats() { - c.RLock() - hostLen := len(c.vpnIps) - indexLen := len(c.indexes) - c.RUnlock() +func (hm *HandshakeManager) EmitStats() { + hm.RLock() + hostLen := len(hm.vpnIps) + indexLen := len(hm.indexes) + hm.RUnlock() metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(hostLen)) metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(indexLen)) - c.mainHostMap.EmitStats() + hm.mainHostMap.EmitStats() } // Utility functions below diff --git a/handshake_manager_test.go b/handshake_manager_test.go index daa867564..7edc55b9c 100644 --- a/handshake_manager_test.go +++ b/handshake_manager_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" @@ -13,21 +14,20 @@ import ( func Test_NewHandshakeManagerVpnIp(t *testing.T) { l := test.NewLogger() - vpncidr := netip.MustParsePrefix("172.1.1.1/24") localrange := netip.MustParsePrefix("10.1.1.1/24") ip := netip.MustParseAddr("172.1.1.2") preferredRanges := []netip.Prefix{localrange} - mainHM := newHostMap(l, vpncidr) + mainHM := newHostMap(l) mainHM.preferredRanges.Store(&preferredRanges) lh := newTestLighthouse() cs := &CertState{ - RawCertificate: []byte{}, - PrivateKey: []byte{}, - Certificate: &dummyCert{}, - RawCertificateNoKey: []byte{}, + defaultVersion: cert.Version1, + privateKey: []byte{}, + v1Cert: &dummyCert{version: cert.Version1}, + v1HandshakeBytes: []byte{}, } blah := NewHandshakeManager(l, mainHM, lh, &udp.NoopConn{}, defaultHandshakeConfig) @@ -41,7 +41,7 @@ func Test_NewHandshakeManagerVpnIp(t *testing.T) { i2 := blah.StartHandshake(ip, nil) assert.Same(t, i, i2) - i.remotes = NewRemoteList(nil) + i.remotes = NewRemoteList([]netip.Addr{}, nil) // Adding something to pending should not affect the main hostmap assert.Len(t, mainHM.Hosts, 0) @@ -79,16 +79,24 @@ func testCountTimerWheelEntries(tw *LockingTimerWheel[netip.Addr]) (c int) { type mockEncWriter struct { } -func (mw *mockEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToVpnAddr(_ header.MessageType, _ header.MessageSubType, _ netip.Addr, _, _, _ []byte) { return } -func (mw *mockEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { +func (mw *mockEncWriter) SendVia(_ *HostInfo, _ *Relay, _, _, _ []byte, _ bool) { return } -func (mw *mockEncWriter) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { +func (mw *mockEncWriter) SendMessageToHostInfo(_ header.MessageType, _ header.MessageSubType, _ *HostInfo, _, _, _ []byte) { return } -func (mw *mockEncWriter) Handshake(vpnIP netip.Addr) {} +func (mw *mockEncWriter) Handshake(_ netip.Addr) {} + +func (mw *mockEncWriter) GetHostInfo(_ netip.Addr) *HostInfo { + return nil +} + +func (mw *mockEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: cert.Version2} +} diff --git a/hostmap.go b/hostmap.go index d83151eb3..99aeb53bf 100644 --- a/hostmap.go +++ b/hostmap.go @@ -35,6 +35,7 @@ const ( Requested = iota PeerRequested Established + Disestablished ) const ( @@ -48,7 +49,7 @@ type Relay struct { State int LocalIndex uint32 RemoteIndex uint32 - PeerIp netip.Addr + PeerAddr netip.Addr } type HostMap struct { @@ -58,7 +59,6 @@ type HostMap struct { RemoteIndexes map[uint32]*HostInfo Hosts map[netip.Addr]*HostInfo preferredRanges atomic.Pointer[[]netip.Prefix] - vpnCIDR netip.Prefix l *logrus.Logger } @@ -68,9 +68,12 @@ type HostMap struct { type RelayState struct { sync.RWMutex - relays map[netip.Addr]struct{} // Set of VpnIp's of Hosts to use as relays to access this peer - relayForByIp map[netip.Addr]*Relay // Maps VpnIps of peers for which this HostInfo is a relay to some Relay info - relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info + relays map[netip.Addr]struct{} // Set of vpnAddr's of Hosts to use as relays to access this peer + // For data race avoidance, the contents of a *Relay are treated immutably. To update a *Relay, copy the existing data, + // modify what needs to be updated, and store the new modified copy in the relayForByIp and relayForByIdx maps (with + // the RelayState Lock held) + relayForByAddr map[netip.Addr]*Relay // Maps vpnAddr of peers for which this HostInfo is a relay to some Relay info + relayForByIdx map[uint32]*Relay // Maps a local index to some Relay info } func (rs *RelayState) DeleteRelay(ip netip.Addr) { @@ -79,6 +82,28 @@ func (rs *RelayState) DeleteRelay(ip netip.Addr) { delete(rs.relays, ip) } +func (rs *RelayState) UpdateRelayForByIpState(vpnIp netip.Addr, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByAddr[vpnIp]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + +func (rs *RelayState) UpdateRelayForByIdxState(idx uint32, state int) { + rs.Lock() + defer rs.Unlock() + if r, ok := rs.relayForByIdx[idx]; ok { + newRelay := *r + newRelay.State = state + rs.relayForByAddr[newRelay.PeerAddr] = &newRelay + rs.relayForByIdx[newRelay.LocalIndex] = &newRelay + } +} + func (rs *RelayState) CopyAllRelayFor() []*Relay { rs.RLock() defer rs.RUnlock() @@ -89,10 +114,10 @@ func (rs *RelayState) CopyAllRelayFor() []*Relay { return ret } -func (rs *RelayState) GetRelayForByIp(ip netip.Addr) (*Relay, bool) { +func (rs *RelayState) GetRelayForByAddr(addr netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[ip] + r, ok := rs.relayForByAddr[addr] return r, ok } @@ -115,8 +140,8 @@ func (rs *RelayState) CopyRelayIps() []netip.Addr { func (rs *RelayState) CopyRelayForIps() []netip.Addr { rs.RLock() defer rs.RUnlock() - currentRelays := make([]netip.Addr, 0, len(rs.relayForByIp)) - for relayIp := range rs.relayForByIp { + currentRelays := make([]netip.Addr, 0, len(rs.relayForByAddr)) + for relayIp := range rs.relayForByAddr { currentRelays = append(currentRelays, relayIp) } return currentRelays @@ -135,7 +160,7 @@ func (rs *RelayState) CopyRelayForIdxs() []uint32 { func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool { rs.Lock() defer rs.Unlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] if !ok { return false } @@ -143,7 +168,7 @@ func (rs *RelayState) CompleteRelayByIP(vpnIp netip.Addr, remoteIdx uint32) bool newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return true } @@ -158,14 +183,14 @@ func (rs *RelayState) CompleteRelayByIdx(localIdx uint32, remoteIdx uint32) (*Re newRelay.State = Established newRelay.RemoteIndex = remoteIdx rs.relayForByIdx[r.LocalIndex] = &newRelay - rs.relayForByIp[r.PeerIp] = &newRelay + rs.relayForByAddr[r.PeerAddr] = &newRelay return &newRelay, true } func (rs *RelayState) QueryRelayForByIp(vpnIp netip.Addr) (*Relay, bool) { rs.RLock() defer rs.RUnlock() - r, ok := rs.relayForByIp[vpnIp] + r, ok := rs.relayForByAddr[vpnIp] return r, ok } @@ -179,7 +204,7 @@ func (rs *RelayState) QueryRelayForByIdx(idx uint32) (*Relay, bool) { func (rs *RelayState) InsertRelay(ip netip.Addr, idx uint32, r *Relay) { rs.Lock() defer rs.Unlock() - rs.relayForByIp[ip] = r + rs.relayForByAddr[ip] = r rs.relayForByIdx[idx] = r } @@ -190,10 +215,12 @@ type HostInfo struct { ConnectionState *ConnectionState remoteIndexId uint32 localIndexId uint32 - vpnIp netip.Addr + vpnAddrs []netip.Addr recvError atomic.Uint32 - remoteCidr *bart.Table[struct{}] - relayState RelayState + + // networks are both all vpn and unsafe networks assigned to this host + networks *bart.Table[struct{}] + relayState RelayState // HandshakePacket records the packets used to create this hostinfo // We need these to avoid replayed handshake packets creating new hostinfos which causes churn @@ -241,28 +268,26 @@ type cachedPacketMetrics struct { dropped metrics.Counter } -func NewHostMapFromConfig(l *logrus.Logger, vpnCIDR netip.Prefix, c *config.C) *HostMap { - hm := newHostMap(l, vpnCIDR) +func NewHostMapFromConfig(l *logrus.Logger, c *config.C) *HostMap { + hm := newHostMap(l) hm.reload(c, true) c.RegisterReloadCallback(func(c *config.C) { hm.reload(c, false) }) - l.WithField("network", hm.vpnCIDR.String()). - WithField("preferredRanges", hm.GetPreferredRanges()). + l.WithField("preferredRanges", hm.GetPreferredRanges()). Info("Main HostMap created") return hm } -func newHostMap(l *logrus.Logger, vpnCIDR netip.Prefix) *HostMap { +func newHostMap(l *logrus.Logger) *HostMap { return &HostMap{ Indexes: map[uint32]*HostInfo{}, Relays: map[uint32]*HostInfo{}, RemoteIndexes: map[uint32]*HostInfo{}, Hosts: map[netip.Addr]*HostInfo{}, - vpnCIDR: vpnCIDR, l: l, } } @@ -305,17 +330,6 @@ func (hm *HostMap) EmitStats() { metrics.GetOrRegisterGauge("hostmap.main.relayIndexes", nil).Update(int64(relaysLen)) } -func (hm *HostMap) RemoveRelay(localIdx uint32) { - hm.Lock() - _, ok := hm.Relays[localIdx] - if !ok { - hm.Unlock() - return - } - delete(hm.Relays, localIdx) - hm.Unlock() -} - // DeleteHostInfo will fully unlink the hostinfo and return true if it was the final hostinfo for this vpn ip func (hm *HostMap) DeleteHostInfo(hostinfo *HostInfo) bool { // Delete the host itself, ensuring it's not modified anymore @@ -335,7 +349,9 @@ func (hm *HostMap) MakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { - oldHostinfo := hm.Hosts[hostinfo.vpnIp] + //TODO: we may need to promote follow on hostinfos from these vpnAddrs as well since their oldHostinfo might not be the same as this one + // this really looks like an ideal spot for memory leaks + oldHostinfo := hm.Hosts[hostinfo.vpnAddrs[0]] if oldHostinfo == hostinfo { return } @@ -348,7 +364,7 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { hostinfo.next.prev = hostinfo.prev } - hm.Hosts[hostinfo.vpnIp] = hostinfo + hm.Hosts[hostinfo.vpnAddrs[0]] = hostinfo if oldHostinfo == nil { return @@ -360,23 +376,36 @@ func (hm *HostMap) unlockedMakePrimary(hostinfo *HostInfo) { } func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { - primary, ok := hm.Hosts[hostinfo.vpnIp] + for _, addr := range hostinfo.vpnAddrs { + h := hm.Hosts[addr] + for h != nil { + if h == hostinfo { + hm.unlockedInnerDeleteHostInfo(h, addr) + } + h = h.next + } + } +} + +func (hm *HostMap) unlockedInnerDeleteHostInfo(hostinfo *HostInfo, addr netip.Addr) { + primary, ok := hm.Hosts[addr] + isLastHostinfo := hostinfo.next == nil && hostinfo.prev == nil if ok && primary == hostinfo { - // The vpnIp pointer points to the same hostinfo as the local index id, we can remove it - delete(hm.Hosts, hostinfo.vpnIp) + // The vpn addr pointer points to the same hostinfo as the local index id, we can remove it + delete(hm.Hosts, addr) if len(hm.Hosts) == 0 { hm.Hosts = map[netip.Addr]*HostInfo{} } if hostinfo.next != nil { - // We had more than 1 hostinfo at this vpnip, promote the next in the list to primary - hm.Hosts[hostinfo.vpnIp] = hostinfo.next + // We had more than 1 hostinfo at this vpn addr, promote the next in the list to primary + hm.Hosts[addr] = hostinfo.next // It is primary, there is no previous hostinfo now hostinfo.next.prev = nil } } else { - // Relink if we were in the middle of multiple hostinfos for this vpn ip + // Relink if we were in the middle of multiple hostinfos for this vpn addr if hostinfo.prev != nil { hostinfo.prev.next = hostinfo.next } @@ -406,10 +435,16 @@ func (hm *HostMap) unlockedDeleteHostInfo(hostinfo *HostInfo) { if hm.l.Level >= logrus.DebugLevel { hm.l.WithField("hostMap", m{"mapTotalSize": len(hm.Hosts), - "vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). + "vpnAddrs": hostinfo.vpnAddrs, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}). Debug("Hostmap hostInfo deleted") } + if isLastHostinfo { + // I have lost connectivity to my peers. My relay tunnel is likely broken. Mark the next + // hops as 'Requested' so that new relay tunnels are created in the future. + hm.unlockedDisestablishVpnAddrRelayFor(hostinfo) + } + // Clean up any local relay indexes for which I am acting as a relay hop for _, localRelayIdx := range hostinfo.relayState.CopyRelayForIdxs() { delete(hm.Relays, localRelayIdx) } @@ -448,11 +483,11 @@ func (hm *HostMap) QueryReverseIndex(index uint32) *HostInfo { } } -func (hm *HostMap) QueryVpnIp(vpnIp netip.Addr) *HostInfo { - return hm.queryVpnIp(vpnIp, nil) +func (hm *HostMap) QueryVpnAddr(vpnIp netip.Addr) *HostInfo { + return hm.queryVpnAddr(vpnIp, nil) } -func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { +func (hm *HostMap) QueryVpnAddrsRelayFor(targetIps []netip.Addr, relayHostIp netip.Addr) (*HostInfo, *Relay, error) { hm.RLock() defer hm.RUnlock() @@ -460,17 +495,42 @@ func (hm *HostMap) QueryVpnIpRelayFor(targetIp, relayHostIp netip.Addr) (*HostIn if !ok { return nil, nil, errors.New("unable to find host") } + for h != nil { - r, ok := h.relayState.QueryRelayForByIp(targetIp) - if ok && r.State == Established { - return h, r, nil + for _, targetIp := range targetIps { + r, ok := h.relayState.QueryRelayForByIp(targetIp) + if ok && r.State == Established { + return h, r, nil + } } h = h.next } + return nil, nil, errors.New("unable to find host with relay") } -func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { +func (hm *HostMap) unlockedDisestablishVpnAddrRelayFor(hi *HostInfo) { + for _, relayHostIp := range hi.relayState.CopyRelayIps() { + if h, ok := hm.Hosts[relayHostIp]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + for _, rs := range hi.relayState.CopyAllRelayFor() { + if rs.Type == ForwardingType { + if h, ok := hm.Hosts[rs.PeerAddr]; ok { + for h != nil { + h.relayState.UpdateRelayForByIpState(hi.vpnAddrs[0], Disestablished) + h = h.next + } + } + } + } +} + +func (hm *HostMap) queryVpnAddr(vpnIp netip.Addr, promoteIfce *Interface) *HostInfo { hm.RLock() if h, ok := hm.Hosts[vpnIp]; ok { hm.RUnlock() @@ -491,25 +551,30 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) + dnsR.Add(remoteCert.Certificate.Name()+".", hostinfo.vpnAddrs) } - - existing := hm.Hosts[hostinfo.vpnIp] - hm.Hosts[hostinfo.vpnIp] = hostinfo - - if existing != nil { - hostinfo.next = existing - existing.prev = hostinfo + for _, addr := range hostinfo.vpnAddrs { + hm.unlockedInnerAddHostInfo(addr, hostinfo, f) } hm.Indexes[hostinfo.localIndexId] = hostinfo hm.RemoteIndexes[hostinfo.remoteIndexId] = hostinfo if hm.l.Level >= logrus.DebugLevel { - hm.l.WithField("hostMap", m{"vpnIp": hostinfo.vpnIp, "mapTotalSize": len(hm.Hosts), - "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "hostId": hostinfo.vpnIp}}). + hm.l.WithField("hostMap", m{"vpnAddrs": hostinfo.vpnAddrs, "mapTotalSize": len(hm.Hosts), + "hostinfo": m{"existing": true, "localIndexId": hostinfo.localIndexId, "vpnAddrs": hostinfo.vpnAddrs}}). Debug("Hostmap vpnIp added") } +} + +func (hm *HostMap) unlockedInnerAddHostInfo(vpnAddr netip.Addr, hostinfo *HostInfo, f *Interface) { + existing := hm.Hosts[vpnAddr] + hm.Hosts[vpnAddr] = hostinfo + + if existing != nil && existing != hostinfo { + hostinfo.next = existing + existing.prev = hostinfo + } i := 1 check := hostinfo @@ -527,7 +592,7 @@ func (hm *HostMap) GetPreferredRanges() []netip.Prefix { return *hm.preferredRanges.Load() } -func (hm *HostMap) ForEachVpnIp(f controlEach) { +func (hm *HostMap) ForEachVpnAddr(f controlEach) { hm.RLock() defer hm.RUnlock() @@ -581,7 +646,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac } i.nextLHQuery.Store(now + ifce.reQueryWait.Load()) - ifce.lightHouse.QueryServer(i.vpnIp) + ifce.lightHouse.QueryServer(i.vpnAddrs[0]) } } @@ -596,7 +661,7 @@ func (i *HostInfo) SetRemote(remote netip.AddrPort) { // We copy here because we likely got this remote from a source that reuses the object if i.remote != remote { i.remote = remote - i.remotes.LearnRemote(i.vpnIp, remote) + i.remotes.LearnRemote(i.vpnAddrs[0], remote) } } @@ -647,21 +712,20 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) { +func (i *HostInfo) buildNetworks(c cert.Certificate) { if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed return } - remoteCidr := new(bart.Table[struct{}]) + i.networks = new(bart.Table[struct{}]) for _, network := range c.Networks() { - remoteCidr.Insert(network, struct{}{}) + i.networks.Insert(network, struct{}{}) } for _, network := range c.UnsafeNetworks() { - remoteCidr.Insert(network, struct{}{}) + i.networks.Insert(network, struct{}{}) } - i.remoteCidr = remoteCidr } func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { @@ -669,7 +733,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { return logrus.NewEntry(l) } - li := l.WithField("vpnIp", i.vpnIp). + li := l.WithField("vpnAddrs", i.vpnAddrs). WithField("localIndex", i.localIndexId). WithField("remoteIndex", i.remoteIndexId) @@ -684,9 +748,9 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { // Utility functions -func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { +func localAddrs(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { //FIXME: This function is pretty garbage - var ips []netip.Addr + var finalAddrs []netip.Addr ifaces, _ := net.Interfaces() for _, i := range ifaces { allow := allowList.AllowName(i.Name) @@ -698,39 +762,38 @@ func localIps(l *logrus.Logger, allowList *LocalAllowList) []netip.Addr { continue } addrs, _ := i.Addrs() - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { + for _, rawAddr := range addrs { + var addr netip.Addr + switch v := rawAddr.(type) { case *net.IPNet: //continue - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) case *net.IPAddr: - ip = v.IP + addr, _ = netip.AddrFromSlice(v.IP) } - nip, ok := netip.AddrFromSlice(ip) - if !ok { + if !addr.IsValid() { if l.Level >= logrus.DebugLevel { - l.WithField("localIp", ip).Debug("ip was invalid for netip") + l.WithField("localAddr", rawAddr).Debug("addr was invalid") } continue } - nip = nip.Unmap() + addr = addr.Unmap() //TODO: Filtering out link local for now, this is probably the most correct thing //TODO: Would be nice to filter out SLAAC MAC based ips as well - if nip.IsLoopback() == false && nip.IsLinkLocalUnicast() == false { - allow := allowList.Allow(nip) + if addr.IsLoopback() == false && addr.IsLinkLocalUnicast() == false { + isAllowed := allowList.Allow(addr) if l.Level >= logrus.TraceLevel { - l.WithField("localIp", nip).WithField("allow", allow).Trace("localAllowList.Allow") + l.WithField("localAddr", addr).WithField("allowed", isAllowed).Trace("localAllowList.Allow") } - if !allow { + if !isAllowed { continue } - ips = append(ips, nip) + finalAddrs = append(finalAddrs, addr) } } } - return ips + return finalAddrs } diff --git a/hostmap_test.go b/hostmap_test.go index 7e2feb810..e974340d0 100644 --- a/hostmap_test.go +++ b/hostmap_test.go @@ -11,17 +11,14 @@ import ( func TestHostMap_MakePrimary(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} hm.unlockedAddHostInfo(h4, f) hm.unlockedAddHostInfo(h3, f) @@ -29,7 +26,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.unlockedAddHostInfo(h1, f) // Make sure we go h1 -> h2 -> h3 -> h4 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -44,7 +41,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h3) // Make sure we go h3 -> h1 -> h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h3.localIndexId, prim.localIndexId) assert.Equal(t, h1.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -59,7 +56,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -74,7 +71,7 @@ func TestHostMap_MakePrimary(t *testing.T) { hm.MakePrimary(h4) // Make sure we go h4 -> h3 -> h1 -> h2 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -88,19 +85,16 @@ func TestHostMap_MakePrimary(t *testing.T) { func TestHostMap_DeleteHostInfo(t *testing.T) { l := test.NewLogger() - hm := newHostMap( - l, - netip.MustParsePrefix("10.0.0.1/24"), - ) + hm := newHostMap(l) f := &Interface{} - h1 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 1} - h2 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 2} - h3 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 3} - h4 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 4} - h5 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 5} - h6 := &HostInfo{vpnIp: netip.MustParseAddr("0.0.0.1"), localIndexId: 6} + h1 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 1} + h2 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 2} + h3 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 3} + h4 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 4} + h5 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 5} + h6 := &HostInfo{vpnAddrs: []netip.Addr{netip.MustParseAddr("0.0.0.1")}, localIndexId: 6} hm.unlockedAddHostInfo(h6, f) hm.unlockedAddHostInfo(h5, f) @@ -116,7 +110,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h) // Make sure we go h1 -> h2 -> h3 -> h4 -> h5 - prim := hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim := hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h1.localIndexId, prim.localIndexId) assert.Equal(t, h2.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -135,7 +129,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h1.next) // Make sure we go h2 -> h3 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h3.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -153,7 +147,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h3.next) // Make sure we go h2 -> h4 -> h5 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -169,7 +163,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h5.next) // Make sure we go h2 -> h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h2.localIndexId, prim.localIndexId) assert.Equal(t, h4.localIndexId, prim.next.localIndexId) assert.Nil(t, prim.prev) @@ -183,7 +177,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h2.next) // Make sure we only have h4 - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Equal(t, h4.localIndexId, prim.localIndexId) assert.Nil(t, prim.prev) assert.Nil(t, prim.next) @@ -195,7 +189,7 @@ func TestHostMap_DeleteHostInfo(t *testing.T) { assert.Nil(t, h4.next) // Make sure we have nil - prim = hm.QueryVpnIp(netip.MustParseAddr("0.0.0.1")) + prim = hm.QueryVpnAddr(netip.MustParseAddr("0.0.0.1")) assert.Nil(t, prim) } @@ -203,11 +197,7 @@ func TestHostMap_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) - hm := NewHostMapFromConfig( - l, - netip.MustParsePrefix("10.0.0.1/24"), - c, - ) + hm := NewHostMapFromConfig(l, c) toS := func(ipn []netip.Prefix) []string { var s []string diff --git a/hostmap_tester.go b/hostmap_tester.go index b2d1d1b5b..fe40c5334 100644 --- a/hostmap_tester.go +++ b/hostmap_tester.go @@ -9,8 +9,8 @@ import ( "net/netip" ) -func (i *HostInfo) GetVpnIp() netip.Addr { - return i.vpnIp +func (i *HostInfo) GetVpnAddrs() []netip.Addr { + return i.vpnAddrs } func (i *HostInfo) GetLocalIndex() uint32 { diff --git a/inside.go b/inside.go index 0ccd17909..9119560fe 100644 --- a/inside.go +++ b/inside.go @@ -20,14 +20,19 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } // Ignore local broadcast packets - if f.dropLocalBroadcast && fwPacket.RemoteIP == f.myBroadcastAddr { - return + if f.dropLocalBroadcast { + _, found := f.myBroadcastAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { + return + } } - if fwPacket.RemoteIP == f.myVpnNet.Addr() { + //TODO: seems like a huge bummer + _, found := f.myVpnAddrsTable.Lookup(fwPacket.RemoteAddr) + if found { // Immediately forward packets from self to self. // This should only happen on Darwin-based and FreeBSD hosts, which - // routes packets from the Nebula IP to the Nebula IP through the Nebula + // routes packets from the Nebula addr to the Nebula addr through the Nebula // TUN device. if immediatelyForwardToSelf { _, err := f.readers[q].Write(packet) @@ -36,25 +41,25 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet } } // Otherwise, drop. On linux, we should never see these packets - Linux - // routes packets from the nebula IP to the nebula IP through the loopback device. + // routes packets from the nebula addr to the nebula addr through the loopback device. return } // Ignore multicast packets - if f.dropMulticast && fwPacket.RemoteIP.IsMulticast() { + if f.dropMulticast && fwPacket.RemoteAddr.IsMulticast() { return } - hostinfo, ready := f.getOrHandshake(fwPacket.RemoteIP, func(hh *HandshakeHostInfo) { + hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics) }) if hostinfo == nil { f.rejectInside(packet, out, q) if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", fwPacket.RemoteIP). + f.l.WithField("vpnAddr", fwPacket.RemoteAddr). WithField("fwPacket", fwPacket). - Debugln("dropping outbound packet, vpnIp not in our CIDR or in unsafe routes") + Debugln("dropping outbound packet, vpnAddr not in our vpn networks or in unsafe networks") } return } @@ -117,21 +122,22 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo * f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q) } -func (f *Interface) Handshake(vpnIp netip.Addr) { - f.getOrHandshake(vpnIp, nil) +func (f *Interface) Handshake(vpnAddr netip.Addr) { + f.getOrHandshake(vpnAddr, nil) } -// getOrHandshake returns nil if the vpnIp is not routable. +// getOrHandshake returns nil if the vpnAddr is not routable. // If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel -func (f *Interface) getOrHandshake(vpnIp netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { - if !f.myVpnNet.Contains(vpnIp) { - vpnIp = f.inside.RouteFor(vpnIp) - if !vpnIp.IsValid() { +func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) { + _, found := f.myVpnNetworksTable.Lookup(vpnAddr) + if !found { + vpnAddr = f.inside.RouteFor(vpnAddr) + if !vpnAddr.IsValid() { return nil, false } } - return f.handshakeManager.GetOrHandshake(vpnIp, cacheCallback) + return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback) } func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) { @@ -156,16 +162,16 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp f.sendNoMetrics(header.Message, st, hostinfo.ConnectionState, hostinfo, netip.AddrPort{}, p, nb, out, 0) } -// SendMessageToVpnIp handles real ip:port lookup and sends to the current best known address for vpnIp -func (f *Interface) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) { - hostInfo, ready := f.getOrHandshake(vpnIp, func(hh *HandshakeHostInfo) { +// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr +func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) { + hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) { hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics) }) if hostInfo == nil { if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", vpnIp). - Debugln("dropping SendMessageToVpnIp, vpnIp not in our CIDR or in unsafe routes") + f.l.WithField("vpnAddr", vpnAddr). + Debugln("dropping SendMessageToVpnAddr, vpnAddr not in our vpn networks or in unsafe routes") } return } @@ -285,14 +291,14 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType f.connectionManager.Out(hostinfo.localIndexId) // Query our LH if we haven't since the last time we've been rebound, this will cause the remote to punch against - // all our IPs and enable a faster roaming. + // all our addrs and enable a faster roaming. if t != header.CloseTunnel && hostinfo.lastRebindCount != f.rebindCount { //NOTE: there is an update hole if a tunnel isn't used and exactly 256 rebinds occur before the tunnel is // finally used again. This tunnel would eventually be torn down and recreated if this action didn't help. - f.lightHouse.QueryServer(hostinfo.vpnIp) + f.lightHouse.QueryServer(hostinfo.vpnAddrs[0]) hostinfo.lastRebindCount = f.rebindCount if f.l.Level >= logrus.DebugLevel { - f.l.WithField("vpnIp", hostinfo.vpnIp).Debug("Lighthouse update triggered for punch due to rebind counter") + f.l.WithField("vpnAddrs", hostinfo.vpnAddrs).Debug("Lighthouse update triggered for punch due to rebind counter") } } @@ -324,7 +330,7 @@ func (f *Interface) sendNoMetrics(t header.MessageType, st header.MessageSubType } else { // Try to send via a relay for _, relayIP := range hostinfo.relayState.CopyRelayIps() { - relayHostInfo, relay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relayIP) + relayHostInfo, relay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relayIP) if err != nil { hostinfo.relayState.DeleteRelay(relayIP) hostinfo.logger(f.l).WithField("relay", relayIP).WithError(err).Info("sendNoMetrics failed to find HostInfo") diff --git a/interface.go b/interface.go index 5d41a8714..9b9315697 100644 --- a/interface.go +++ b/interface.go @@ -2,17 +2,16 @@ package nebula import ( "context" - "encoding/binary" "errors" "fmt" "io" - "net" "net/netip" "os" "runtime" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -29,7 +28,6 @@ type InterfaceConfig struct { Outside udp.Conn Inside overlay.Device pki *PKI - Cipher string Firewall *Firewall ServeDns bool HandshakeManager *HandshakeManager @@ -53,25 +51,27 @@ type InterfaceConfig struct { } type Interface struct { - hostMap *HostMap - outside udp.Conn - inside overlay.Device - pki *PKI - cipher string - firewall *Firewall - connectionManager *connectionManager - handshakeManager *HandshakeManager - serveDns bool - createTime time.Time - lightHouse *LightHouse - myBroadcastAddr netip.Addr - myVpnNet netip.Prefix - dropLocalBroadcast bool - dropMulticast bool - routines int - disconnectInvalid atomic.Bool - closed atomic.Bool - relayManager *relayManager + hostMap *HostMap + outside udp.Conn + inside overlay.Device + pki *PKI + firewall *Firewall + connectionManager *connectionManager + handshakeManager *HandshakeManager + serveDns bool + createTime time.Time + lightHouse *LightHouse + myBroadcastAddrsTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr // A list of addresses assigned to us via our certificate + myVpnAddrsTable *bart.Table[struct{}] // A table of addresses assigned to us via our certificate + myVpnNetworks []netip.Prefix // A table of networks assigned to us via our certificate + myVpnNetworksTable *bart.Table[struct{}] // A table of networks assigned to us via our certificate + dropLocalBroadcast bool + dropMulticast bool + routines int + disconnectInvalid atomic.Bool + closed atomic.Bool + relayManager *relayManager tryPromoteEvery atomic.Uint32 reQueryEvery atomic.Uint32 @@ -103,9 +103,11 @@ type EncWriter interface { out []byte, nocopy bool, ) - SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, nb, out []byte) + SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) SendMessageToHostInfo(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) - Handshake(vpnIp netip.Addr) + Handshake(vpnAddr netip.Addr) + GetHostInfo(vpnAddr netip.Addr) *HostInfo + GetCertState() *CertState } type sendRecvErrorConfig uint8 @@ -116,10 +118,10 @@ const ( sendRecvErrorPrivate ) -func (s sendRecvErrorConfig) ShouldSendRecvError(ip netip.AddrPort) bool { +func (s sendRecvErrorConfig) ShouldSendRecvError(endpoint netip.AddrPort) bool { switch s { case sendRecvErrorPrivate: - return ip.Addr().IsPrivate() + return endpoint.Addr().IsPrivate() case sendRecvErrorAlways: return true case sendRecvErrorNever: @@ -156,27 +158,29 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { return nil, errors.New("no firewall rules") } - certificate := c.pki.GetCertState().Certificate - + cs := c.pki.getCertState() ifce := &Interface{ - pki: c.pki, - hostMap: c.HostMap, - outside: c.Outside, - inside: c.Inside, - cipher: c.Cipher, - firewall: c.Firewall, - serveDns: c.ServeDns, - handshakeManager: c.HandshakeManager, - createTime: time.Now(), - lightHouse: c.lightHouse, - dropLocalBroadcast: c.DropLocalBroadcast, - dropMulticast: c.DropMulticast, - routines: c.routines, - version: c.version, - writers: make([]udp.Conn, c.routines), - readers: make([]io.ReadWriteCloser, c.routines), - myVpnNet: certificate.Networks()[0], - relayManager: c.relayManager, + pki: c.pki, + hostMap: c.HostMap, + outside: c.Outside, + inside: c.Inside, + firewall: c.Firewall, + serveDns: c.ServeDns, + handshakeManager: c.HandshakeManager, + createTime: time.Now(), + lightHouse: c.lightHouse, + dropLocalBroadcast: c.DropLocalBroadcast, + dropMulticast: c.DropMulticast, + routines: c.routines, + version: c.version, + writers: make([]udp.Conn, c.routines), + readers: make([]io.ReadWriteCloser, c.routines), + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + myVpnAddrs: cs.myVpnAddrs, + myVpnAddrsTable: cs.myVpnAddrsTable, + myBroadcastAddrsTable: cs.myVpnBroadcastAddrsTable, + relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -190,14 +194,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - if ifce.myVpnNet.Addr().Is4() { - maskedAddr := certificate.Networks()[0].Masked() - addr := maskedAddr.Addr().As4() - mask := net.CIDRMask(maskedAddr.Bits(), maskedAddr.Addr().BitLen()) - binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) - ifce.myBroadcastAddr = netip.AddrFrom4(addr) - } - ifce.tryPromoteEvery.Store(c.tryPromoteEvery) ifce.reQueryEvery.Store(c.reQueryEvery) ifce.reQueryWait.Store(int64(c.reQueryWait)) @@ -218,7 +214,7 @@ func (f *Interface) activate() { f.l.WithError(err).Error("Failed to get udp listen address") } - f.l.WithField("interface", f.inside.Name()).WithField("network", f.inside.Cidr().String()). + f.l.WithField("interface", f.inside.Name()).WithField("networks", f.myVpnNetworks). WithField("build", f.version).WithField("udpAddr", addr). WithField("boringcrypto", boringEnabled()). Info("Nebula interface is active") @@ -259,16 +255,22 @@ func (f *Interface) listenOut(i int) { runtime.LockOSThread() var li udp.Conn - // TODO clean this up with a coherent interface for each outside connection if i > 0 { li = f.writers[i] } else { li = f.outside } + ctCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) lhh := f.lightHouse.NewRequestHandler() - conntrackCache := firewall.NewConntrackCacheTicker(f.conntrackCacheTimeout) - li.ListenOut(readOutsidePackets(f), lhHandleRequest(lhh, f), conntrackCache, i) + plaintext := make([]byte, udp.MTU) + h := &header.H{} + fwPacket := &firewall.Packet{} + nb := make([]byte, 12, 12) + + li.ListenOut(func(fromUdpAddr netip.AddrPort, payload []byte) { + f.readOutsidePackets(fromUdpAddr, nil, plaintext[:0], payload, h, fwPacket, lhh, nb, i, ctCache.Get(f.l)) + }) } func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { @@ -325,7 +327,7 @@ func (f *Interface) reloadFirewall(c *config.C) { return } - fw, err := NewFirewallFromConfig(f.l, f.pki.GetCertState().Certificate, c) + fw, err := NewFirewallFromConfig(f.l, f.pki.getCertState(), c) if err != nil { f.l.WithError(err).Error("Error while creating firewall during reload") return @@ -417,11 +419,20 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.getCertState().GetDefaultCertificate().NotAfter().Sub(time.Now()) / time.Second)) + //TODO: we should also report the default certificate version } } } +func (f *Interface) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return f.hostMap.QueryVpnAddr(vpnIp) +} + +func (f *Interface) GetCertState() *CertState { + return f.pki.getCertState() +} + func (f *Interface) Close() error { f.closed.Store(true) diff --git a/lighthouse.go b/lighthouse.go index 62f406560..59e894101 100644 --- a/lighthouse.go +++ b/lighthouse.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/netip" + "slices" "strconv" "sync" "sync/atomic" @@ -15,28 +16,28 @@ import ( "github.com/gaissmai/bart" "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/util" ) -//TODO: if a lighthouse doesn't have an answer, clients AGGRESSIVELY REQUERY.. why? handshake manager and/or getOrHandshake? -//TODO: nodes are roaming lighthouses, this is bad. How are they learning? - var ErrHostNotKnown = errors.New("host not known") type LightHouse struct { - //TODO: We need a timer wheel to kick out vpnIps that haven't reported in a long time + //TODO: We need a timer wheel to kick out vpnAddrs that haven't reported in a long time sync.RWMutex //Because we concurrently read and write to our maps ctx context.Context amLighthouse bool - myVpnNet netip.Prefix - punchConn udp.Conn - punchy *Punchy + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + punchConn udp.Conn + punchy *Punchy // Local cache of answers from light houses - // map of vpn Ip to answers + // map of vpn addr to answers addrMap map[netip.Addr]*RemoteList // filters remote addresses allowed for each host @@ -64,12 +65,12 @@ type LightHouse struct { advertiseAddrs atomic.Pointer[[]netip.AddrPort] - // IP's of relays that can be used by peers to access me + // Addr's of relays that can be used by peers to access me relaysForMe atomic.Pointer[[]netip.Addr] queryChan chan netip.Addr - calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnIp to []*calculatedRemote + calculatedRemotes atomic.Pointer[bart.Table[[]*calculatedRemote]] // Maps VpnAddr to []*calculatedRemote metrics *MessageMetrics metricHolepunchTx metrics.Counter @@ -78,7 +79,7 @@ type LightHouse struct { // NewLightHouseFromConfig will build a Lighthouse struct from the values provided in the config object // addrMap should be nil unless this is during a config reload -func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, myVpnNet netip.Prefix, pc udp.Conn, p *Punchy) (*LightHouse, error) { +func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, cs *CertState, pc udp.Conn, p *Punchy) (*LightHouse, error) { amLighthouse := c.GetBool("lighthouse.am_lighthouse", false) nebulaPort := uint32(c.GetInt("listen.port", 0)) if amLighthouse && nebulaPort == 0 { @@ -95,15 +96,16 @@ func NewLightHouseFromConfig(ctx context.Context, l *logrus.Logger, c *config.C, } h := LightHouse{ - ctx: ctx, - amLighthouse: amLighthouse, - myVpnNet: myVpnNet, - addrMap: make(map[netip.Addr]*RemoteList), - nebulaPort: nebulaPort, - punchConn: pc, - punchy: p, - queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), - l: l, + ctx: ctx, + amLighthouse: amLighthouse, + myVpnNetworks: cs.myVpnNetworks, + myVpnNetworksTable: cs.myVpnNetworksTable, + addrMap: make(map[netip.Addr]*RemoteList), + nebulaPort: nebulaPort, + punchConn: pc, + punchy: p, + queryChan: make(chan netip.Addr, c.GetUint32("handshakes.query_buffer", 64)), + l: l, } lighthouses := make(map[netip.Addr]struct{}) h.lighthouses.Store(&lighthouses) @@ -180,11 +182,11 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { return util.NewContextualError("Unable to parse lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - ips, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) + addrs, err := net.DefaultResolver.LookupNetIP(context.Background(), "ip", host) if err != nil { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, err) } - if len(ips) == 0 { + if len(addrs) == 0 { return util.NewContextualError("Unable to lookup lighthouse.advertise_addrs entry", m{"addr": rawAddr, "entry": i + 1}, nil) } @@ -197,15 +199,16 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { port = int(lh.nebulaPort) } - //TODO: we could technically insert all returned ips instead of just the first one if a dns lookup was used - ip := ips[0].Unmap() - if lh.myVpnNet.Contains(ip) { + //TODO: we could technically insert all returned addrs instead of just the first one if a dns lookup was used + addr := addrs[0].Unmap() + _, found := lh.myVpnNetworksTable.Lookup(addr) + if found { lh.l.WithField("addr", rawAddr).WithField("entry", i+1). Warn("Ignoring lighthouse.advertise_addrs report because it is within the nebula network range") continue } - advAddrs = append(advAddrs, netip.AddrPortFrom(ip, uint16(port))) + advAddrs = append(advAddrs, netip.AddrPortFrom(addr, uint16(port))) } lh.advertiseAddrs.Store(&advAddrs) @@ -275,8 +278,8 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { // Entries no longer present must have their (possible) background DNS goroutines stopped. if existingStaticList := lh.staticList.Load(); existingStaticList != nil { lh.RLock() - for staticVpnIp := range *existingStaticList { - if am, ok := lh.addrMap[staticVpnIp]; ok && am != nil { + for staticVpnAddr := range *existingStaticList { + if am, ok := lh.addrMap[staticVpnAddr]; ok && am != nil { am.hr.Cancel() } } @@ -333,11 +336,11 @@ func (lh *LightHouse) reload(c *config.C, initial bool) error { case false: relaysForMe := []netip.Addr{} for _, v := range c.GetStringSlice("relay.relays", nil) { - lh.l.WithField("relay", v).Info("Read relay from config") - configRIP, err := netip.ParseAddr(v) - //TODO: We could print the error here - if err == nil { + if err != nil { + lh.l.WithField("relay", v).WithError(err).Warn("Parse relay from config failed") + } else { + lh.l.WithField("relay", v).Info("Read relay from config") relaysForMe = append(relaysForMe, configRIP) } } @@ -355,14 +358,16 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ } for i, host := range lhs { - ip, err := netip.ParseAddr(host) + addr, err := netip.ParseAddr(host) if err != nil { return util.NewContextualError("Unable to parse lighthouse host entry", m{"host": host, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(ip) { - return util.NewContextualError("lighthouse host is not in our subnet, invalid", m{"vpnIp": ip, "network": lh.myVpnNet}, nil) + + _, found := lh.myVpnNetworksTable.Lookup(addr) + if !found { + return util.NewContextualError("lighthouse host is not in our networks, invalid", m{"vpnAddr": addr, "networks": lh.myVpnNetworks}, nil) } - lhMap[ip] = struct{}{} + lhMap[addr] = struct{}{} } if !lh.amLighthouse && len(lhMap) == 0 { @@ -370,9 +375,9 @@ func (lh *LightHouse) parseLighthouses(c *config.C, lhMap map[netip.Addr]struct{ } staticList := lh.GetStaticHostList() - for lhIP, _ := range lhMap { - if _, ok := staticList[lhIP]; !ok { - return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhIP) + for lhAddr, _ := range lhMap { + if _, ok := staticList[lhAddr]; !ok { + return fmt.Errorf("lighthouse %s does not have a static_host_map entry", lhAddr) } } @@ -425,13 +430,14 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc i := 0 for k, v := range shm { - vpnIp, err := netip.ParseAddr(fmt.Sprintf("%v", k)) + vpnAddr, err := netip.ParseAddr(fmt.Sprintf("%v", k)) if err != nil { return util.NewContextualError("Unable to parse static_host_map entry", m{"host": k, "entry": i + 1}, err) } - if !lh.myVpnNet.Contains(vpnIp) { - return util.NewContextualError("static_host_map key is not in our subnet, invalid", m{"vpnIp": vpnIp, "network": lh.myVpnNet, "entry": i + 1}, nil) + _, found := lh.myVpnNetworksTable.Lookup(vpnAddr) + if !found { + return util.NewContextualError("static_host_map key is not in our network, invalid", m{"vpnAddr": vpnAddr, "networks": lh.myVpnNetworks, "entry": i + 1}, nil) } vals, ok := v.([]interface{}) @@ -443,7 +449,7 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc remoteAddrs = append(remoteAddrs, fmt.Sprintf("%v", v)) } - err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnIp, remoteAddrs, staticList) + err = lh.addStaticRemotes(i, d, network, lookupTimeout, vpnAddr, remoteAddrs, staticList) if err != nil { return err } @@ -453,12 +459,12 @@ func (lh *LightHouse) loadStaticMap(c *config.C, staticList map[netip.Addr]struc return nil } -func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { - if !lh.IsLighthouseIP(ip) { - lh.QueryServer(ip) +func (lh *LightHouse) Query(vpnAddr netip.Addr) *RemoteList { + if !lh.IsLighthouseAddr(vpnAddr) { + lh.QueryServer(vpnAddr) } lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { lh.RUnlock() return v } @@ -467,18 +473,18 @@ func (lh *LightHouse) Query(ip netip.Addr) *RemoteList { } // QueryServer is asynchronous so no reply should be expected -func (lh *LightHouse) QueryServer(ip netip.Addr) { - // Don't put lighthouse ips in the query channel because we can't query lighthouses about lighthouses - if lh.amLighthouse || lh.IsLighthouseIP(ip) { +func (lh *LightHouse) QueryServer(vpnAddr netip.Addr) { + // Don't put lighthouse addrs in the query channel because we can't query lighthouses about lighthouses + if lh.amLighthouse || lh.IsLighthouseAddr(vpnAddr) { return } - lh.queryChan <- ip + lh.queryChan <- vpnAddr } -func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { +func (lh *LightHouse) QueryCache(vpnAddrs []netip.Addr) *RemoteList { lh.RLock() - if v, ok := lh.addrMap[ip]; ok { + if v, ok := lh.addrMap[vpnAddrs[0]]; ok { lh.RUnlock() return v } @@ -487,24 +493,27 @@ func (lh *LightHouse) QueryCache(ip netip.Addr) *RemoteList { lh.Lock() defer lh.Unlock() // Add an entry if we don't already have one - return lh.unlockedGetRemoteList(ip) + return lh.unlockedGetRemoteList(vpnAddrs) } // queryAndPrepMessage is a lock helper on RemoteList, assisting the caller to build a lighthouse message containing -// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnIp +// details from the remote list. It looks for a hit in the addrMap and a hit in the RemoteList under the owner vpnAddr // If one is found then f() is called with proper locking, f() must return result of n.MarshalTo() -func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, error)) (bool, int, error) { +func (lh *LightHouse) queryAndPrepMessage(vpnAddr netip.Addr, f func(*cache) (int, error)) (bool, int, error) { lh.RLock() // Do we have an entry in the main cache? - if v, ok := lh.addrMap[vpnIp]; ok { + if v, ok := lh.addrMap[vpnAddr]; ok { // Swap lh lock for remote list lock v.RLock() defer v.RUnlock() lh.RUnlock() - // vpnIp should also be the owner here since we are a lighthouse. - c := v.cache[vpnIp] + // We may be asking about a non primary address so lets get the primary address + if slices.Contains(v.vpnAddrs, vpnAddr) { + vpnAddr = v.vpnAddrs[0] + } + c := v.cache[vpnAddr] // Make sure we have if c != nil { n, err := f(c) @@ -516,112 +525,140 @@ func (lh *LightHouse) queryAndPrepMessage(vpnIp netip.Addr, f func(*cache) (int, return false, 0, nil } -func (lh *LightHouse) DeleteVpnIp(vpnIp netip.Addr) { +func (lh *LightHouse) DeleteVpnAddrs(allVpnAddrs []netip.Addr) { // First we check the static mapping // and do nothing if it is there - if _, ok := lh.GetStaticHostList()[vpnIp]; ok { + if _, ok := lh.GetStaticHostList()[allVpnAddrs[0]]; ok { return } lh.Lock() - //l.Debugln(lh.addrMap) - delete(lh.addrMap, vpnIp) - - if lh.l.Level >= logrus.DebugLevel { - lh.l.Debugf("deleting %s from lighthouse.", vpnIp) + rm, ok := lh.addrMap[allVpnAddrs[0]] + if ok { + for _, addr := range allVpnAddrs { + srm := lh.addrMap[addr] + if srm == rm { + delete(lh.addrMap, addr) + if lh.l.Level >= logrus.DebugLevel { + lh.l.Debugf("deleting %s from lighthouse.", addr) + } + } + } } - lh.Unlock() } -// AddStaticRemote adds a static host entry for vpnIp as ourselves as the owner +// AddStaticRemote adds a static host entry for vpnAddr as ourselves as the owner // We are the owner because we don't want a lighthouse server to advertise for static hosts it was configured with // And we don't want a lighthouse query reply to interfere with our learned cache if we are a client // NOTE: this function should not interact with any hot path objects, like lh.staticList, the caller should handle it -func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnIp netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { +func (lh *LightHouse) addStaticRemotes(i int, d time.Duration, network string, timeout time.Duration, vpnAddr netip.Addr, toAddrs []string, staticList map[netip.Addr]struct{}) error { lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() ctx := lh.ctx lh.Unlock() hr, err := NewHostnameResults(ctx, lh.l, d, network, timeout, toAddrs, func() { - // This callback runs whenever the DNS hostname resolver finds a different set of IP's + // This callback runs whenever the DNS hostname resolver finds a different set of addr's // in its resolution for hostnames. am.Lock() defer am.Unlock() am.shouldRebuild = true }) if err != nil { - return util.NewContextualError("Static host address could not be parsed", m{"vpnIp": vpnIp, "entry": i + 1}, err) + return util.NewContextualError("Static host address could not be parsed", m{"vpnAddr": vpnAddr, "entry": i + 1}, err) } am.unlockedSetHostnamesResults(hr) - for _, addrPort := range hr.GetIPs() { - if !lh.shouldAdd(vpnIp, addrPort.Addr()) { + for _, addrPort := range hr.GetAddrs() { + if !lh.shouldAdd(vpnAddr, addrPort.Addr()) { continue } switch { case addrPort.Addr().Is4(): - am.unlockedPrependV4(lh.myVpnNet.Addr(), NewIp4AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV4(lh.myVpnNetworks[0].Addr(), netAddrToProtoV4AddrPort(addrPort.Addr(), addrPort.Port())) case addrPort.Addr().Is6(): - am.unlockedPrependV6(lh.myVpnNet.Addr(), NewIp6AndPortFromNetIP(addrPort.Addr(), addrPort.Port())) + am.unlockedPrependV6(lh.myVpnNetworks[0].Addr(), netAddrToProtoV6AddrPort(addrPort.Addr(), addrPort.Port())) } } // Mark it as static in the caller provided map - staticList[vpnIp] = struct{}{} + staticList[vpnAddr] = struct{}{} return nil } // addCalculatedRemotes adds any calculated remotes based on the // lighthouse.calculated_remotes configuration. It returns true if any // calculated remotes were added -func (lh *LightHouse) addCalculatedRemotes(vpnIp netip.Addr) bool { +func (lh *LightHouse) addCalculatedRemotes(vpnAddr netip.Addr) bool { tree := lh.getCalculatedRemotes() if tree == nil { return false } - calculatedRemotes, ok := tree.Lookup(vpnIp) + calculatedRemotes, ok := tree.Lookup(vpnAddr) if !ok { return false } - var calculated []*Ip4AndPort + var calculatedV4 []*V4AddrPort + var calculatedV6 []*V6AddrPort for _, cr := range calculatedRemotes { - c := cr.Apply(vpnIp) - if c != nil { - calculated = append(calculated, c) + if vpnAddr.Is4() { + c := cr.ApplyV4(vpnAddr) + if c != nil { + calculatedV4 = append(calculatedV4, c) + } + } else if vpnAddr.Is6() { + c := cr.ApplyV6(vpnAddr) + if c != nil { + calculatedV6 = append(calculatedV6, c) + } } } lh.Lock() - am := lh.unlockedGetRemoteList(vpnIp) + am := lh.unlockedGetRemoteList([]netip.Addr{vpnAddr}) am.Lock() defer am.Unlock() lh.Unlock() - am.unlockedSetV4(lh.myVpnNet.Addr(), vpnIp, calculated, lh.unlockedShouldAddV4) + if len(calculatedV4) > 0 { + am.unlockedSetV4(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV4, lh.unlockedShouldAddV4) + } + + if len(calculatedV6) > 0 { + am.unlockedSetV6(lh.myVpnNetworks[0].Addr(), vpnAddr, calculatedV6, lh.unlockedShouldAddV6) + } - return len(calculated) > 0 + return len(calculatedV4) > 0 || len(calculatedV6) > 0 } -// unlockedGetRemoteList assumes you have the lh lock -func (lh *LightHouse) unlockedGetRemoteList(vpnIp netip.Addr) *RemoteList { - am, ok := lh.addrMap[vpnIp] +// unlockedGetRemoteList +// assumes you have the lh lock +func (lh *LightHouse) unlockedGetRemoteList(allAddrs []netip.Addr) *RemoteList { + am, ok := lh.addrMap[allAddrs[0]] if !ok { - am = NewRemoteList(func(a netip.Addr) bool { return lh.shouldAdd(vpnIp, a) }) - lh.addrMap[vpnIp] = am + am = NewRemoteList(allAddrs, func(a netip.Addr) bool { return lh.shouldAdd(allAddrs[0], a) }) + for _, addr := range allAddrs { + lh.addrMap[addr] = am + } } return am } -func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { - allow := lh.GetRemoteAllowList().Allow(vpnIp, to) +func (lh *LightHouse) shouldAdd(vpnAddr netip.Addr, to netip.Addr) bool { + allow := lh.GetRemoteAllowList().Allow(vpnAddr, to) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", to).WithField("allow", allow). + Trace("remoteAllowList.Allow") + } + if !allow { + return false } - if !allow || lh.myVpnNet.Contains(to) { + + _, found := lh.myVpnNetworksTable.Lookup(to) + if found { return false } @@ -629,14 +666,20 @@ func (lh *LightHouse) shouldAdd(vpnIp netip.Addr, to netip.Addr) bool { } // unlockedShouldAddV4 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool { - ip := AddrPortFromIp4AndPort(to) - allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) +func (lh *LightHouse) unlockedShouldAddV4(vpnAddr netip.Addr, to *V4AddrPort) bool { + udpAddr := protoV4AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", vpnIp).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") + } + + if !allow { + return false } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { return false } @@ -644,78 +687,43 @@ func (lh *LightHouse) unlockedShouldAddV4(vpnIp netip.Addr, to *Ip4AndPort) bool } // unlockedShouldAddV6 checks if to is allowed by our allow list -func (lh *LightHouse) unlockedShouldAddV6(vpnIp netip.Addr, to *Ip6AndPort) bool { - ip := AddrPortFromIp6AndPort(to) - allow := lh.GetRemoteAllowList().Allow(vpnIp, ip.Addr()) +func (lh *LightHouse) unlockedShouldAddV6(vpnAddr netip.Addr, to *V6AddrPort) bool { + udpAddr := protoV6AddrPortToNetAddrPort(to) + allow := lh.GetRemoteAllowList().Allow(vpnAddr, udpAddr.Addr()) if lh.l.Level >= logrus.TraceLevel { - lh.l.WithField("remoteIp", lhIp6ToIp(to)).WithField("allow", allow).Trace("remoteAllowList.Allow") + lh.l.WithField("vpnAddr", vpnAddr).WithField("udpAddr", udpAddr).WithField("allow", allow). + Trace("remoteAllowList.Allow") } - if !allow || lh.myVpnNet.Contains(ip.Addr()) { + if !allow { return false } - return true -} + _, found := lh.myVpnNetworksTable.Lookup(udpAddr.Addr()) + if found { + return false + } -func lhIp6ToIp(v *Ip6AndPort) net.IP { - ip := make(net.IP, 16) - binary.BigEndian.PutUint64(ip[:8], v.Hi) - binary.BigEndian.PutUint64(ip[8:], v.Lo) - return ip + return true } -func (lh *LightHouse) IsLighthouseIP(vpnIp netip.Addr) bool { - if _, ok := lh.GetLighthouses()[vpnIp]; ok { +func (lh *LightHouse) IsLighthouseAddr(vpnAddr netip.Addr) bool { + if _, ok := lh.GetLighthouses()[vpnAddr]; ok { return true } return false } -func NewLhQueryByInt(vpnIp netip.Addr) *NebulaMeta { - if vpnIp.Is6() { - //TODO: need to support ipv6 - panic("ipv6 is not yet supported") - } - - b := vpnIp.As4() - return &NebulaMeta{ - Type: NebulaMeta_HostQuery, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - }, - } -} - -func AddrPortFromIp4AndPort(ip *Ip4AndPort) netip.AddrPort { - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], ip.Ip) - return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ip.Port)) -} - -func AddrPortFromIp6AndPort(ip *Ip6AndPort) netip.AddrPort { - b := [16]byte{} - binary.BigEndian.PutUint64(b[:8], ip.Hi) - binary.BigEndian.PutUint64(b[8:], ip.Lo) - return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ip.Port)) -} - -func NewIp4AndPortFromNetIP(ip netip.Addr, port uint16) *Ip4AndPort { - v4Addr := ip.As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), - Port: uint32(port), - } -} - -// TODO: IPV6-WORK we can delete some more of these -func NewIp6AndPortFromNetIP(ip netip.Addr, port uint16) *Ip6AndPort { - ip6Addr := ip.As16() - return &Ip6AndPort{ - Hi: binary.BigEndian.Uint64(ip6Addr[:8]), - Lo: binary.BigEndian.Uint64(ip6Addr[8:]), - Port: uint32(port), +// TODO: IsLighthouseAddr should be sufficient, we just need to update the vpnAddrs for lighthouses after a handshake +// so that we know all the lighthouse vpnAddrs, not just the ones we were configured to talk to initially +func (lh *LightHouse) IsAnyLighthouseAddr(vpnAddr []netip.Addr) bool { + l := lh.GetLighthouses() + for _, a := range vpnAddr { + if _, ok := l[a]; ok { + return true + } } + return false } func (lh *LightHouse) startQueryWorker() { @@ -731,31 +739,85 @@ func (lh *LightHouse) startQueryWorker() { select { case <-lh.ctx.Done(): return - case ip := <-lh.queryChan: - lh.innerQueryServer(ip, nb, out) + case addr := <-lh.queryChan: + lh.innerQueryServer(addr, nb, out) } } }() } -func (lh *LightHouse) innerQueryServer(ip netip.Addr, nb, out []byte) { - if lh.IsLighthouseIP(ip) { +func (lh *LightHouse) innerQueryServer(addr netip.Addr, nb, out []byte) { + if lh.IsLighthouseAddr(addr) { return } - // Send a query to the lighthouses and hope for the best next time - query, err := NewLhQueryByInt(ip).Marshal() - if err != nil { - lh.l.WithError(err).WithField("vpnIp", ip).Error("Failed to marshal lighthouse query payload") - return + msg := &NebulaMeta{ + Type: NebulaMeta_HostQuery, + Details: &NebulaMetaDetails{}, } + var v1Query, v2Query []byte + var err error + var v cert.Version + queried := 0 lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostQuery, int64(len(lighthouses))) - for n := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, n, query, nb, out) + for lhVpnAddr := range lighthouses { + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + + if v == cert.Version1 { + if !addr.Is4() { + lh.l.WithField("queryVpnAddr", addr).WithField("lighthouseAddr", lhVpnAddr). + Error("Can't query lighthouse for v6 address using a v1 protocol") + continue + } + + if v1Query == nil { + b := addr.As4() + msg.Details.VpnAddr = nil + msg.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + + v1Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v1 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Query, nb, out) + queried++ + + } else if v == cert.Version2 { + if v2Query == nil { + msg.Details.OldVpnAddr = 0 + msg.Details.VpnAddr = netAddrToProtoAddr(addr) + + v2Query, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("queryVpnAddr", addr). + WithField("lighthouseAddr", lhVpnAddr). + Error("Failed to marshal lighthouse v2 query payload") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Query, nb, out) + queried++ + + } else { + lh.l.Debugf("Can not query lighthouse for %v using unknown protocol version: %v", addr, v) + continue + } } + + lh.metricTx(NebulaMeta_HostQuery, int64(queried)) } func (lh *LightHouse) StartUpdateWorker() { @@ -785,65 +847,120 @@ func (lh *LightHouse) StartUpdateWorker() { } func (lh *LightHouse) SendUpdate() { - var v4 []*Ip4AndPort - var v6 []*Ip6AndPort + var v4 []*V4AddrPort + var v6 []*V6AddrPort for _, e := range lh.GetAdvertiseAddrs() { if e.Addr().Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e.Addr(), e.Port())) + v4 = append(v4, netAddrToProtoV4AddrPort(e.Addr(), e.Port())) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e.Addr(), e.Port())) + v6 = append(v6, netAddrToProtoV6AddrPort(e.Addr(), e.Port())) } } lal := lh.GetLocalAllowList() - for _, e := range localIps(lh.l, lal) { - if lh.myVpnNet.Contains(e) { + for _, e := range localAddrs(lh.l, lal) { + _, found := lh.myVpnNetworksTable.Lookup(e) + if found { continue } - // Only add IPs that aren't my VPN/tun IP + // Only add addrs that aren't my VPN/tun networks if e.Is4() { - v4 = append(v4, NewIp4AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v4 = append(v4, netAddrToProtoV4AddrPort(e, uint16(lh.nebulaPort))) } else { - v6 = append(v6, NewIp6AndPortFromNetIP(e, uint16(lh.nebulaPort))) + v6 = append(v6, netAddrToProtoV6AddrPort(e, uint16(lh.nebulaPort))) } } - var relays []uint32 - for _, r := range lh.GetRelaysForMe() { - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := r.As4() - relays = append(relays, binary.BigEndian.Uint32(b[:])) - } + nb := make([]byte, 12, 12) + out := make([]byte, mtu) - //TODO: IPV6-WORK both relays and vpnip need ipv6 support - b := lh.myVpnNet.Addr().As4() + var v1Update, v2Update []byte + var err error + updated := 0 + lighthouses := lh.GetLighthouses() - m := &NebulaMeta{ - Type: NebulaMeta_HostUpdateNotification, - Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(b[:]), - Ip4AndPorts: v4, - Ip6AndPorts: v6, - RelayVpnIp: relays, - }, - } + for lhVpnAddr := range lighthouses { + var v cert.Version + hi := lh.ifce.GetHostInfo(lhVpnAddr) + if hi != nil { + v = hi.ConnectionState.myCert.Version() + } else { + v = lh.ifce.GetCertState().defaultVersion + } + if v == cert.Version1 { + if v1Update == nil { + if !lh.myVpnNetworks[0].Addr().Is4() { + lh.l.WithField("lighthouseAddr", lhVpnAddr). + Warn("cannot update lighthouse using v1 protocol without an IPv4 address") + continue + } + var relays []uint32 + for _, r := range lh.GetRelaysForMe() { + if !r.Is4() { + continue + } + b := r.As4() + relays = append(relays, binary.BigEndian.Uint32(b[:])) + } + b := lh.myVpnNetworks[0].Addr().As4() + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + OldRelayVpnAddrs: relays, + OldVpnAddr: binary.BigEndian.Uint32(b[:]), + }, + } - lighthouses := lh.GetLighthouses() - lh.metricTx(NebulaMeta_HostUpdateNotification, int64(len(lighthouses))) - nb := make([]byte, 12, 12) - out := make([]byte, mtu) + v1Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v1 update") + continue + } + } - mm, err := m.Marshal() - if err != nil { - lh.l.WithError(err).Error("Error while marshaling for lighthouse update") - return - } + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v1Update, nb, out) + updated++ + + } else if v == cert.Version2 { + if v2Update == nil { + var relays []*Addr + for _, r := range lh.GetRelaysForMe() { + relays = append(relays, netAddrToProtoAddr(r)) + } + + msg := NebulaMeta{ + Type: NebulaMeta_HostUpdateNotification, + Details: &NebulaMetaDetails{ + V4AddrPorts: v4, + V6AddrPorts: v6, + RelayVpnAddrs: relays, + VpnAddr: netAddrToProtoAddr(lh.myVpnNetworks[0].Addr()), + }, + } + + v2Update, err = msg.Marshal() + if err != nil { + lh.l.WithError(err).WithField("lighthouseAddr", lhVpnAddr). + Error("Error while marshaling for lighthouse v2 update") + continue + } + } + + lh.ifce.SendMessageToVpnAddr(header.LightHouse, 0, lhVpnAddr, v2Update, nb, out) + updated++ - for vpnIp := range lighthouses { - lh.ifce.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, mm, nb, out) + } else { + lh.l.Debugf("Can not update lighthouse using unknown protocol version: %v", v) + continue + } } + + lh.metricTx(NebulaMeta_HostUpdateNotification, int64(updated)) } type LightHouseHandler struct { @@ -886,34 +1003,30 @@ func (lhh *LightHouseHandler) resetMeta() *NebulaMeta { lhh.meta.Reset() // Keep the array memory around - details.Ip4AndPorts = details.Ip4AndPorts[:0] - details.Ip6AndPorts = details.Ip6AndPorts[:0] - details.RelayVpnIp = details.RelayVpnIp[:0] + details.V4AddrPorts = details.V4AddrPorts[:0] + details.V6AddrPorts = details.V6AddrPorts[:0] + details.RelayVpnAddrs = details.RelayVpnAddrs[:0] + details.OldRelayVpnAddrs = details.OldRelayVpnAddrs[:0] + //TODO: these are unfortunate + details.OldVpnAddr = 0 + details.VpnAddr = nil lhh.meta.Details = details return lhh.meta } -func lhHandleRequest(lhh *LightHouseHandler, f *Interface) udp.LightHouseHandlerFunc { - return func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) { - lhh.HandleRequest(rAddr, vpnIp, p, f) - } -} - -func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte, w EncWriter) { +func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, fromVpnAddrs []netip.Addr, p []byte, w EncWriter) { n := lhh.resetMeta() err := n.Unmarshal(p) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Failed to unmarshal lighthouse packet") - //TODO: send recv_error? return } if n.Details == nil { - lhh.l.WithField("vpnIp", vpnIp).WithField("udpAddr", rAddr). + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("udpAddr", rAddr). Error("Invalid lighthouse update") - //TODO: send recv_error? return } @@ -921,24 +1034,24 @@ func (lhh *LightHouseHandler) HandleRequest(rAddr netip.AddrPort, vpnIp netip.Ad switch n.Type { case NebulaMeta_HostQuery: - lhh.handleHostQuery(n, vpnIp, rAddr, w) + lhh.handleHostQuery(n, fromVpnAddrs, rAddr, w) case NebulaMeta_HostQueryReply: - lhh.handleHostQueryReply(n, vpnIp) + lhh.handleHostQueryReply(n, fromVpnAddrs) case NebulaMeta_HostUpdateNotification: - lhh.handleHostUpdateNotification(n, vpnIp, w) + lhh.handleHostUpdateNotification(n, fromVpnAddrs, w) case NebulaMeta_HostMovedNotification: case NebulaMeta_HostPunchNotification: - lhh.handleHostPunchNotification(n, vpnIp, w) + lhh.handleHostPunchNotification(n, fromVpnAddrs, w) case NebulaMeta_HostUpdateNotificationAck: // noop } } -func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, addr netip.AddrPort, w EncWriter) { +func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, fromVpnAddrs []netip.Addr, addr netip.AddrPort, w EncWriter) { // Exit if we don't answer queries if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { @@ -947,21 +1060,38 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a return } - //TODO: we can DRY this further - reqVpnIp := n.Details.VpnIp - - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) + useVersion := cert.Version1 + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + useVersion = 1 + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = 2 + } else { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("from", fromVpnAddrs).WithField("details", n.Details).Debugln("Dropping malformed HostQuery") + } + return + } //TODO: Maybe instead of marshalling into n we marshal into a new `r` to not nuke our current request data - found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnIp, func(c *cache) (int, error) { + found, ln, err := lhh.lh.queryAndPrepMessage(queryVpnAddr, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostQueryReply - n.Details.VpnIp = reqVpnIp + if useVersion == cert.Version1 { + if !queryVpnAddr.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := queryVpnAddr.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else { + n.Details.VpnAddr = netAddrToProtoAddr(queryVpnAddr) + } - lhh.coalesceAnswers(c, n) + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -971,21 +1101,51 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host query reply") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host query reply") return } lhh.lh.metricTx(NebulaMeta_HostQueryReply, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) + + lhh.sendHostPunchNotification(n, fromVpnAddrs, queryVpnAddr, w) +} - // This signals the other side to punch some zero byte udp packets - found, ln, err = lhh.lh.queryAndPrepMessage(vpnIp, func(c *cache) (int, error) { +// sendHostPunchNotification signals the other side to punch some zero byte udp packets +func (lhh *LightHouseHandler) sendHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, punchNotifDest netip.Addr, w EncWriter) { + whereToPunch := fromVpnAddrs[0] + found, ln, err := lhh.lh.queryAndPrepMessage(whereToPunch, func(c *cache) (int, error) { n = lhh.resetMeta() n.Type = NebulaMeta_HostPunchNotification - //TODO: IPV6-WORK - b = vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(b[:]) - lhh.coalesceAnswers(c, n) + targetHI := lhh.lh.ifce.GetHostInfo(punchNotifDest) + var useVersion cert.Version + if targetHI == nil { + useVersion = lhh.lh.ifce.GetCertState().defaultVersion + } else { + crt := targetHI.GetCert().Certificate + useVersion = crt.Version() + // we can only retarget if we have a hostinfo + newDest, ok := findNetworkUnion(crt.Networks(), fromVpnAddrs) + if ok { + whereToPunch = newDest + } else { + //TODO this means the destination will have no addresses in common with the punch-ee + //choosing to do nothing for now, but maybe we return an error? + } + } + + if useVersion == cert.Version1 { + if !whereToPunch.Is4() { + return 0, fmt.Errorf("invalid vpn addr for v1 handleHostQuery") + } + b := whereToPunch.As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(b[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(whereToPunch) + } else { + return 0, errors.New("unsupported version") + } + lhh.coalesceAnswers(useVersion, c, n) return n.MarshalTo(lhh.pb) }) @@ -995,139 +1155,168 @@ func (lhh *LightHouseHandler) handleHostQuery(n *NebulaMeta, vpnIp netip.Addr, a } if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host was queried for") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host was queried for") return } lhh.lh.metricTx(NebulaMeta_HostPunchNotification, 1) - - //TODO: IPV6-WORK - binary.BigEndian.PutUint32(b[:], reqVpnIp) - sendTo := netip.AddrFrom4(b) - w.SendMessageToVpnIp(header.LightHouse, 0, sendTo, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, punchNotifDest, lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) coalesceAnswers(c *cache, n *NebulaMeta) { +func (lhh *LightHouseHandler) coalesceAnswers(v cert.Version, c *cache, n *NebulaMeta) { if c.v4 != nil { if c.v4.learned != nil { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.learned) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.learned) } if c.v4.reported != nil && len(c.v4.reported) > 0 { - n.Details.Ip4AndPorts = append(n.Details.Ip4AndPorts, c.v4.reported...) + n.Details.V4AddrPorts = append(n.Details.V4AddrPorts, c.v4.reported...) } } if c.v6 != nil { if c.v6.learned != nil { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.learned) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.learned) } if c.v6.reported != nil && len(c.v6.reported) > 0 { - n.Details.Ip6AndPorts = append(n.Details.Ip6AndPorts, c.v6.reported...) + n.Details.V6AddrPorts = append(n.Details.V6AddrPorts, c.v6.reported...) } } if c.relay != nil { - //TODO: IPV6-WORK - relays := make([]uint32, len(c.relay.relay)) - b := [4]byte{} - for i, _ := range relays { - b = c.relay.relay[i].As4() - relays[i] = binary.BigEndian.Uint32(b[:]) + if v == cert.Version1 { + b := [4]byte{} + for _, r := range c.relay.relay { + if !r.Is4() { + continue + } + + b = r.As4() + n.Details.OldRelayVpnAddrs = append(n.Details.OldRelayVpnAddrs, binary.BigEndian.Uint32(b[:])) + } + + } else if v == cert.Version2 { + for _, r := range c.relay.relay { + n.Details.RelayVpnAddrs = append(n.Details.RelayVpnAddrs, netAddrToProtoAddr(r)) + } + + } else { + panic("unsupported version") } - n.Details.RelayVpnIp = append(n.Details.RelayVpnIp, relays...) } } -func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, vpnIp netip.Addr) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostQueryReply(n *NebulaMeta, fromVpnAddrs []netip.Addr) { + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } lhh.lh.Lock() - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - certVpnIp := netip.AddrFrom4(b) - am := lhh.lh.unlockedGetRemoteList(certVpnIp) + + var certVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + certVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + certVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + relays := n.Details.GetRelays() + + am := lhh.lh.unlockedGetRemoteList([]netip.Addr{certVpnAddr}) am.Lock() lhh.lh.Unlock() - //TODO: IPV6-WORK - am.unlockedSetV4(vpnIp, certVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, certVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) - } - am.unlockedSetRelay(vpnIp, certVpnIp, relays) + am.unlockedSetV4(fromVpnAddrs[0], certVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], certVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() // Non-blocking attempt to trigger, skip if it would block select { - case lhh.lh.handshakeTrigger <- certVpnIp: + case lhh.lh.handshakeTrigger <- certVpnAddr: default: } } -func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { +func (lhh *LightHouseHandler) handleHostUpdateNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { if !lhh.lh.amLighthouse { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", vpnIp) + lhh.l.Debugln("I am not a lighthouse, do not take host updates: ", fromVpnAddrs) + } + return + } + + var detailsVpnAddr netip.Addr + useVersion := cert.Version1 + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + detailsVpnAddr = netip.AddrFrom4(b) + useVersion = cert.Version1 + } else if n.Details.VpnAddr != nil { + detailsVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + useVersion = cert.Version2 + } else { + if lhh.l.Level >= logrus.DebugLevel { + lhh.l.WithField("details", n.Details).Debugf("dropping invalid HostUpdateNotification") } return } + //todo hosts with only v2 certs cannot provide their ipv6 addr when contacting the lighthouse via v4? + //todo why do we care about the vpnAddr in the packet? We know where it came from, right? //Simple check that the host sent this not someone else - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - detailsVpnIp := netip.AddrFrom4(b) - if detailsVpnIp != vpnIp { + if !slices.Contains(fromVpnAddrs, detailsVpnAddr) { if lhh.l.Level >= logrus.DebugLevel { - lhh.l.WithField("vpnIp", vpnIp).WithField("answer", detailsVpnIp).Debugln("Host sent invalid update") + lhh.l.WithField("vpnAddrs", fromVpnAddrs).WithField("answer", detailsVpnAddr).Debugln("Host sent invalid update") } return } + relays := n.Details.GetRelays() + lhh.lh.Lock() - am := lhh.lh.unlockedGetRemoteList(vpnIp) + am := lhh.lh.unlockedGetRemoteList(fromVpnAddrs) am.Lock() lhh.lh.Unlock() - am.unlockedSetV4(vpnIp, detailsVpnIp, n.Details.Ip4AndPorts, lhh.lh.unlockedShouldAddV4) - am.unlockedSetV6(vpnIp, detailsVpnIp, n.Details.Ip6AndPorts, lhh.lh.unlockedShouldAddV6) - - //TODO: IPV6-WORK - relays := make([]netip.Addr, len(n.Details.RelayVpnIp)) - for i, _ := range n.Details.RelayVpnIp { - binary.BigEndian.PutUint32(b[:], n.Details.RelayVpnIp[i]) - relays[i] = netip.AddrFrom4(b) - } - am.unlockedSetRelay(vpnIp, detailsVpnIp, relays) + am.unlockedSetV4(fromVpnAddrs[0], detailsVpnAddr, n.Details.V4AddrPorts, lhh.lh.unlockedShouldAddV4) + am.unlockedSetV6(fromVpnAddrs[0], detailsVpnAddr, n.Details.V6AddrPorts, lhh.lh.unlockedShouldAddV6) + am.unlockedSetRelay(fromVpnAddrs[0], relays) am.Unlock() n = lhh.resetMeta() n.Type = NebulaMeta_HostUpdateNotificationAck - //TODO: IPV6-WORK - vpnIpB := vpnIp.As4() - n.Details.VpnIp = binary.BigEndian.Uint32(vpnIpB[:]) - ln, err := n.MarshalTo(lhh.pb) + if useVersion == cert.Version1 { + if !fromVpnAddrs[0].Is4() { + lhh.l.WithField("vpnAddrs", fromVpnAddrs).Error("Can not send HostUpdateNotificationAck for a ipv6 vpn ip in a v1 message") + return + } + vpnAddrB := fromVpnAddrs[0].As4() + n.Details.OldVpnAddr = binary.BigEndian.Uint32(vpnAddrB[:]) + } else if useVersion == cert.Version2 { + n.Details.VpnAddr = netAddrToProtoAddr(fromVpnAddrs[0]) + } else { + lhh.l.WithField("useVersion", useVersion).Error("invalid protocol version") + return + } + ln, err := n.MarshalTo(lhh.pb) if err != nil { - lhh.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to marshal lighthouse host update ack") + lhh.l.WithError(err).WithField("vpnAddrs", fromVpnAddrs).Error("Failed to marshal lighthouse host update ack") return } lhh.lh.metricTx(NebulaMeta_HostUpdateNotificationAck, 1) - w.SendMessageToVpnIp(header.LightHouse, 0, vpnIp, lhh.pb[:ln], lhh.nb, lhh.out[:0]) + w.SendMessageToVpnAddr(header.LightHouse, 0, fromVpnAddrs[0], lhh.pb[:ln], lhh.nb, lhh.out[:0]) } -func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp netip.Addr, w EncWriter) { - if !lhh.lh.IsLighthouseIP(vpnIp) { +func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, fromVpnAddrs []netip.Addr, w EncWriter) { + //It's possible the lighthouse is communicating with us using a non primary vpn addr, + //which means we need to compare all fromVpnAddrs against all configured lighthouse vpn addrs. + //maybe one day we'll have a better idea, if it matters. + if !lhh.lh.IsAnyLighthouseAddr(fromVpnAddrs) { return } @@ -1144,39 +1333,123 @@ func (lhh *LightHouseHandler) handleHostPunchNotification(n *NebulaMeta, vpnIp n }() if lhh.l.Level >= logrus.DebugLevel { - //TODO: lacking the ip we are actually punching on, old: l.Debugf("Punching %s on %d for %s", IntIp(a.Ip), a.Port, IntIp(n.Details.VpnIp)) - //TODO: IPV6-WORK, make this debug line not suck - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - lhh.l.Debugf("Punching on %d for %v", vpnPeer.Port(), netip.AddrFrom4(b)) + var logVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + logVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + logVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + lhh.l.Debugf("Punching on %v for %v", vpnPeer, logVpnAddr) } } - for _, a := range n.Details.Ip4AndPorts { - punch(AddrPortFromIp4AndPort(a)) + for _, a := range n.Details.V4AddrPorts { + punch(protoV4AddrPortToNetAddrPort(a)) } - for _, a := range n.Details.Ip6AndPorts { - punch(AddrPortFromIp6AndPort(a)) + for _, a := range n.Details.V6AddrPorts { + punch(protoV6AddrPortToNetAddrPort(a)) } // This sends a nebula test packet to the host trying to contact us. In the case // of a double nat or other difficult scenario, this may help establish // a tunnel. if lhh.lh.punchy.GetRespond() { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], n.Details.VpnIp) - queryVpnIp := netip.AddrFrom4(b) + var queryVpnAddr netip.Addr + if n.Details.OldVpnAddr != 0 { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], n.Details.OldVpnAddr) + queryVpnAddr = netip.AddrFrom4(b) + } else if n.Details.VpnAddr != nil { + queryVpnAddr = protoAddrToNetAddr(n.Details.VpnAddr) + } + go func() { time.Sleep(lhh.lh.punchy.GetRespondDelay()) if lhh.l.Level >= logrus.DebugLevel { - lhh.l.Debugf("Sending a nebula test packet to vpn ip %s", queryVpnIp) + lhh.l.Debugf("Sending a nebula test packet to vpn addr %s", queryVpnAddr) } //NOTE: we have to allocate a new output buffer here since we are spawning a new goroutine // for each punchBack packet. We should move this into a timerwheel or a single goroutine // managed by a channel. - w.SendMessageToVpnIp(header.Test, header.TestRequest, queryVpnIp, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) + w.SendMessageToVpnAddr(header.Test, header.TestRequest, queryVpnAddr, []byte(""), make([]byte, 12, 12), make([]byte, mtu)) }() } } + +func protoAddrToNetAddr(addr *Addr) netip.Addr { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], addr.Hi) + binary.BigEndian.PutUint64(b[8:], addr.Lo) + return netip.AddrFrom16(b).Unmap() +} + +func protoV4AddrPortToNetAddrPort(ap *V4AddrPort) netip.AddrPort { + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], ap.Addr) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(ap.Port)) +} + +func protoV6AddrPortToNetAddrPort(ap *V6AddrPort) netip.AddrPort { + b := [16]byte{} + binary.BigEndian.PutUint64(b[:8], ap.Hi) + binary.BigEndian.PutUint64(b[8:], ap.Lo) + return netip.AddrPortFrom(netip.AddrFrom16(b), uint16(ap.Port)) +} + +func netAddrToProtoAddr(addr netip.Addr) *Addr { + b := addr.As16() + return &Addr{ + Hi: binary.BigEndian.Uint64(b[:8]), + Lo: binary.BigEndian.Uint64(b[8:]), + } +} + +func netAddrToProtoV4AddrPort(addr netip.Addr, port uint16) *V4AddrPort { + v4Addr := addr.As4() + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), + Port: uint32(port), + } +} + +func netAddrToProtoV6AddrPort(addr netip.Addr, port uint16) *V6AddrPort { + v6Addr := addr.As16() + return &V6AddrPort{ + Hi: binary.BigEndian.Uint64(v6Addr[:8]), + Lo: binary.BigEndian.Uint64(v6Addr[8:]), + Port: uint32(port), + } +} + +func (d *NebulaMetaDetails) GetRelays() []netip.Addr { + var relays []netip.Addr + if len(d.OldRelayVpnAddrs) > 0 { + b := [4]byte{} + for _, r := range d.OldRelayVpnAddrs { + binary.BigEndian.PutUint32(b[:], r) + relays = append(relays, netip.AddrFrom4(b)) + } + } + + if len(d.RelayVpnAddrs) > 0 { + for _, r := range d.RelayVpnAddrs { + relays = append(relays, protoAddrToNetAddr(r)) + } + } + return relays +} + +// FindNetworkUnion returns the first netip.Addr contained in the list of provided netip.Prefix, if able +func findNetworkUnion(prefixes []netip.Prefix, addrs []netip.Addr) (netip.Addr, bool) { + for i := range prefixes { + for j := range addrs { + if prefixes[i].Contains(addrs[j]) { + return addrs[j], true + } + } + } + return netip.Addr{}, false +} diff --git a/lighthouse_test.go b/lighthouse_test.go index 2599f5f2e..8eeb8ce5b 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -7,6 +7,8 @@ import ( "net/netip" "testing" + "github.com/gaissmai/bart" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" @@ -19,57 +21,48 @@ import ( func TestOldIPv4Only(t *testing.T) { // This test ensures our new ipv6 enabled LH protobuf IpAndPorts works with the old style to enable backwards compatibility b := []byte{8, 129, 130, 132, 80, 16, 10} - var m Ip4AndPort + var m V4AddrPort err := m.Unmarshal(b) assert.NoError(t, err) ip := netip.MustParseAddr("10.1.1.1") bp := ip.As4() - assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetIp()) -} - -func TestNewLhQuery(t *testing.T) { - myIp, err := netip.ParseAddr("192.1.1.1") - assert.NoError(t, err) - - // Generating a new lh query should work - a := NewLhQueryByInt(myIp) - - // The result should be a nebulameta protobuf - assert.IsType(t, &NebulaMeta{}, a) - - // It should also Marshal fine - b, err := a.Marshal() - assert.Nil(t, err) - - // and then Unmarshal fine - n := &NebulaMeta{} - err = n.Unmarshal(b) - assert.Nil(t, err) - + assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) } func Test_lhStaticMapping(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - _, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.Nil(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} - _, err = NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/16") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } lh1 := "10.128.0.2" c := config.NewC(l) @@ -79,7 +72,7 @@ func TestReloadLighthouseInterval(t *testing.T) { } c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) lh.ifce = &mockEncWriter{} @@ -99,9 +92,15 @@ func TestReloadLighthouseInterval(t *testing.T) { func BenchmarkLighthouseHandleRequest(b *testing.B) { l := test.NewLogger() myVpnNet := netip.MustParsePrefix("10.128.0.1/0") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } c := config.NewC(l) - lh, err := NewLightHouseFromConfig(context.Background(), l, c, myVpnNet, nil, nil) + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) if !assert.NoError(b, err) { b.Fatal() } @@ -110,46 +109,47 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") vpnIp3 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp3] = NewRemoteList(nil) + lh.addrMap[vpnIp3] = NewRemoteList([]netip.Addr{vpnIp3}, nil) lh.addrMap[vpnIp3].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(hAddr.Addr(), hAddr.Port()), - NewIp4AndPortFromNetIP(hAddr2.Addr(), hAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(hAddr.Addr(), hAddr.Port()), + netAddrToProtoV4AddrPort(hAddr2.Addr(), hAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rAddr := netip.MustParseAddrPort("1.2.2.3:12345") rAddr2 := netip.MustParseAddrPort("1.2.2.3:12346") vpnIp2 := netip.MustParseAddr("0.0.0.3") - lh.addrMap[vpnIp2] = NewRemoteList(nil) + lh.addrMap[vpnIp2] = NewRemoteList([]netip.Addr{vpnIp2}, nil) lh.addrMap[vpnIp2].unlockedSetV4( vpnIp3, vpnIp3, - []*Ip4AndPort{ - NewIp4AndPortFromNetIP(rAddr.Addr(), rAddr.Port()), - NewIp4AndPortFromNetIP(rAddr2.Addr(), rAddr2.Port()), + []*V4AddrPort{ + netAddrToProtoV4AddrPort(rAddr.Addr(), rAddr.Port()), + netAddrToProtoV4AddrPort(rAddr2.Addr(), rAddr2.Port()), }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) mw := &mockEncWriter{} + hi := []netip.Addr{vpnIp2} b.Run("notfound", func(b *testing.B) { lhh := lh.NewRequestHandler() req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 4, - Ip4AndPorts: nil, + OldVpnAddr: 4, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) b.Run("found", func(b *testing.B) { @@ -157,15 +157,15 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: 3, - Ip4AndPorts: nil, + OldVpnAddr: 3, + V4AddrPorts: nil, }, } p, err := req.Marshal() assert.NoError(b, err) for n := 0; n < b.N; n++ { - lhh.HandleRequest(rAddr, vpnIp2, p, mw) + lhh.HandleRequest(rAddr, hi, p, mw) } }) } @@ -197,40 +197,49 @@ func TestLighthouse_Memory(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) + lh.ifce = &mockEncWriter{} assert.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr2}, lhh) r := newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr2) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2) // Ensure we don't accumulate addresses newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr3}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr3) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr3) // Grow it back to 2 newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{myUdpAddr1, myUdpAddr4}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Update a different host and ask about it newLHHostUpdate(theirUdpAddr0, theirVpnIp, []netip.AddrPort{theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4}, lhh) r = newLHHostRequest(theirUdpAddr0, theirVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Have both hosts ask about the other r = newLHHostRequest(theirUdpAddr0, theirVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) r = newLHHostRequest(myUdpAddr0, myVpnIp, theirVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, theirUdpAddr1, theirUdpAddr2, theirUdpAddr3, theirUdpAddr4) // Make sure we didn't get changed r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, myUdpAddr1, myUdpAddr4) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr4) // Ensure proper ordering and limiting // Send 12 addrs, get 10 back, the last 2 removed, allowing the duplicate to remain (clients dedupe) @@ -255,7 +264,7 @@ func TestLighthouse_Memory(t *testing.T) { r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) assertIp4InArray( t, - r.msg.Details.Ip4AndPorts, + r.msg.Details.V4AddrPorts, myUdpAddr1, myUdpAddr2, myUdpAddr3, myUdpAddr4, myUdpAddr5, myUdpAddr5, myUdpAddr6, myUdpAddr7, myUdpAddr8, myUdpAddr9, ) @@ -265,7 +274,7 @@ func TestLighthouse_Memory(t *testing.T) { good := netip.MustParseAddrPort("1.128.0.99:4242") newLHHostUpdate(myUdpAddr0, myVpnIp, []netip.AddrPort{bad1, bad2, good}, lhh) r = newLHHostRequest(myUdpAddr0, myVpnIp, myVpnIp, lhh) - assertIp4InArray(t, r.msg.Details.Ip4AndPorts, good) + assertIp4InArray(t, r.msg.Details.V4AddrPorts, good) } func TestLighthouse_reload(t *testing.T) { @@ -273,7 +282,16 @@ func TestLighthouse_reload(t *testing.T) { c := config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"am_lighthouse": true} c.Settings["listen"] = map[interface{}]interface{}{"port": 4242} - lh, err := NewLightHouseFromConfig(context.Background(), l, c, netip.MustParsePrefix("10.128.0.1/24"), nil, nil) + + myVpnNet := netip.MustParsePrefix("10.128.0.1/24") + nt := new(bart.Table[struct{}]) + nt.Insert(myVpnNet, struct{}{}) + cs := &CertState{ + myVpnNetworks: []netip.Prefix{myVpnNet}, + myVpnNetworksTable: nt, + } + + lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) assert.NoError(t, err) nc := map[interface{}]interface{}{ @@ -295,7 +313,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l req := &NebulaMeta{ Type: NebulaMeta_HostQuery, Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), + OldVpnAddr: binary.BigEndian.Uint32(bip[:]), }, } @@ -308,7 +326,7 @@ func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, l w := &testEncWriter{ metaFilter: &filter, } - lhh.HandleRequest(fromAddr, myVpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{myVpnIp}, b, w) return w.lastReply } @@ -318,13 +336,13 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad req := &NebulaMeta{ Type: NebulaMeta_HostUpdateNotification, Details: &NebulaMetaDetails{ - VpnIp: binary.BigEndian.Uint32(bip[:]), - Ip4AndPorts: make([]*Ip4AndPort, len(addrs)), + OldVpnAddr: binary.BigEndian.Uint32(bip[:]), + V4AddrPorts: make([]*V4AddrPort, len(addrs)), }, } for k, v := range addrs { - req.Details.Ip4AndPorts[k] = NewIp4AndPortFromNetIP(v.Addr(), v.Port()) + req.Details.V4AddrPorts[k] = netAddrToProtoV4AddrPort(v.Addr(), v.Port()) } b, err := req.Marshal() @@ -333,7 +351,7 @@ func newLHHostUpdate(fromAddr netip.AddrPort, vpnIp netip.Addr, addrs []netip.Ad } w := &testEncWriter{} - lhh.HandleRequest(fromAddr, vpnIp, b, w) + lhh.HandleRequest(fromAddr, []netip.Addr{vpnIp}, b, w) } //TODO: this is a RemoteList test @@ -410,8 +428,9 @@ type testLhReply struct { } type testEncWriter struct { - lastReply testLhReply - metaFilter *NebulaMeta_MessageType + lastReply testLhReply + metaFilter *NebulaMeta_MessageType + protocolVersion cert.Version } func (tw *testEncWriter) SendVia(via *HostInfo, relay *Relay, ad, nb, out []byte, nocopy bool) { @@ -426,7 +445,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M tw.lastReply = testLhReply{ nebType: t, nebSubType: st, - vpnIp: hostinfo.vpnIp, + vpnIp: hostinfo.vpnAddrs[0], msg: msg, } } @@ -436,7 +455,7 @@ func (tw *testEncWriter) SendMessageToHostInfo(t header.MessageType, st header.M } } -func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { +func (tw *testEncWriter) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnIp netip.Addr, p, _, _ []byte) { msg := &NebulaMeta{} err := msg.Unmarshal(p) if tw.metaFilter == nil || msg.Type == *tw.metaFilter { @@ -453,17 +472,85 @@ func (tw *testEncWriter) SendMessageToVpnIp(t header.MessageType, st header.Mess } } +func (tw *testEncWriter) GetHostInfo(vpnIp netip.Addr) *HostInfo { + return nil +} + +func (tw *testEncWriter) GetCertState() *CertState { + return &CertState{defaultVersion: tw.protocolVersion} +} + // assertIp4InArray asserts every address in want is at the same position in have and that the lengths match -func assertIp4InArray(t *testing.T, have []*Ip4AndPort, want ...netip.AddrPort) { +func assertIp4InArray(t *testing.T, have []*V4AddrPort, want ...netip.AddrPort) { if !assert.Len(t, have, len(want)) { return } for k, w := range want { //TODO: IPV6-WORK - h := AddrPortFromIp4AndPort(have[k]) + h := protoV4AddrPortToNetAddrPort(have[k]) if !(h == w) { assert.Fail(t, fmt.Sprintf("Response did not contain: %v at %v, found %v", w, k, h)) } } } + +func Test_findNetworkUnion(t *testing.T) { + var out netip.Addr + var ok bool + + tenDot := netip.MustParsePrefix("10.0.0.0/8") + oneSevenTwo := netip.MustParsePrefix("172.16.0.0/16") + fe80 := netip.MustParsePrefix("fe80::/8") + fc00 := netip.MustParsePrefix("fc00::/7") + + a1 := netip.MustParseAddr("10.0.0.1") + afe81 := netip.MustParseAddr("fe80::1") + + //simple + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed lengths + out, ok = findNetworkUnion([]netip.Prefix{tenDot}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //mixed family + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, a1) + + //ordering + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, a1) + out, ok = findNetworkUnion([]netip.Prefix{fe80, tenDot, oneSevenTwo}, []netip.Addr{afe81, a1}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //some mismatches + out, ok = findNetworkUnion([]netip.Prefix{tenDot, oneSevenTwo, fe80}, []netip.Addr{afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1, afe81}) + assert.True(t, ok) + assert.Equal(t, out, afe81) + + //falsey cases + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00, fe80}, []netip.Addr{a1}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{oneSevenTwo, fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) + out, ok = findNetworkUnion([]netip.Prefix{fc00}, []netip.Addr{a1, afe81}) + assert.False(t, ok) +} diff --git a/main.go b/main.go index 8f4535951..1e77a4767 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,6 @@ package nebula import ( "context" - "encoding/binary" "fmt" "net" "net/netip" @@ -61,15 +60,12 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg return nil, util.ContextualizeIfNeeded("Failed to load PKI from config", err) } - certificate := pki.GetCertState().Certificate - fw, err := NewFirewallFromConfig(l, certificate, c) + fw, err := NewFirewallFromConfig(l, pki.getCertState(), c) if err != nil { return nil, util.ContextualizeIfNeeded("Error while loading firewall rules", err) } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - tunCidr := certificate.Networks()[0] - ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { return nil, util.ContextualizeIfNeeded("Error while creating SSH server", err) @@ -132,7 +128,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg deviceFactory = overlay.NewDeviceFromConfig } - tun, err = deviceFactory(c, l, tunCidr, routines) + tun, err = deviceFactory(c, l, pki.getCertState().myVpnNetworks, routines) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to get a tun/tap device", err) } @@ -187,9 +183,9 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } } - hostMap := NewHostMapFromConfig(l, tunCidr, c) + hostMap := NewHostMapFromConfig(l, c) punchy := NewPunchyFromConfig(l, c) - lightHouse, err := NewLightHouseFromConfig(ctx, l, c, tunCidr, udpConns[0], punchy) + lightHouse, err := NewLightHouseFromConfig(ctx, l, c, pki.getCertState(), udpConns[0], punchy) if err != nil { return nil, util.ContextualizeIfNeeded("Failed to initialize lighthouse handler", err) } @@ -232,7 +228,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg Inside: tun, Outside: udpConns[0], pki: pki, - Cipher: c.GetString("cipher", "aes"), Firewall: fw, ServeDns: serveDns, HandshakeManager: handshakeManager, @@ -254,15 +249,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg l: l, } - switch ifConfig.Cipher { - case "aes": - noiseEndianness = binary.BigEndian - case "chachapoly": - noiseEndianness = binary.LittleEndian - default: - return nil, fmt.Errorf("unknown cipher: %v", ifConfig.Cipher) - } - var ifce *Interface if !configTest { ifce, err = NewInterface(ctx, ifConfig) @@ -303,7 +289,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg var dnsStart func() if lightHouse.amLighthouse && serveDns { l.Debugln("Starting dns server") - dnsStart = dnsMain(l, hostMap, c) + dnsStart = dnsMain(l, pki.getCertState(), hostMap, c) } return &Control{ diff --git a/nebula.pb.go b/nebula.pb.go index b3c723a46..2fd2ff665 100644 --- a/nebula.pb.go +++ b/nebula.pb.go @@ -96,7 +96,7 @@ func (x NebulaPing_MessageType) String() string { } func (NebulaPing_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4, 0} + return fileDescriptor_2d65afa7693df5ef, []int{5, 0} } type NebulaControl_MessageType int32 @@ -124,7 +124,7 @@ func (x NebulaControl_MessageType) String() string { } func (NebulaControl_MessageType) EnumDescriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7, 0} + return fileDescriptor_2d65afa7693df5ef, []int{8, 0} } type NebulaMeta struct { @@ -180,11 +180,13 @@ func (m *NebulaMeta) GetDetails() *NebulaMetaDetails { } type NebulaMetaDetails struct { - VpnIp uint32 `protobuf:"varint,1,opt,name=VpnIp,proto3" json:"VpnIp,omitempty"` - Ip4AndPorts []*Ip4AndPort `protobuf:"bytes,2,rep,name=Ip4AndPorts,proto3" json:"Ip4AndPorts,omitempty"` - Ip6AndPorts []*Ip6AndPort `protobuf:"bytes,4,rep,name=Ip6AndPorts,proto3" json:"Ip6AndPorts,omitempty"` - RelayVpnIp []uint32 `protobuf:"varint,5,rep,packed,name=RelayVpnIp,proto3" json:"RelayVpnIp,omitempty"` - Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` + OldVpnAddr uint32 `protobuf:"varint,1,opt,name=OldVpnAddr,proto3" json:"OldVpnAddr,omitempty"` // Deprecated: Do not use. + VpnAddr *Addr `protobuf:"bytes,6,opt,name=VpnAddr,proto3" json:"VpnAddr,omitempty"` + OldRelayVpnAddrs []uint32 `protobuf:"varint,5,rep,packed,name=OldRelayVpnAddrs,proto3" json:"OldRelayVpnAddrs,omitempty"` // Deprecated: Do not use. + RelayVpnAddrs []*Addr `protobuf:"bytes,7,rep,name=RelayVpnAddrs,proto3" json:"RelayVpnAddrs,omitempty"` + V4AddrPorts []*V4AddrPort `protobuf:"bytes,2,rep,name=V4AddrPorts,proto3" json:"V4AddrPorts,omitempty"` + V6AddrPorts []*V6AddrPort `protobuf:"bytes,4,rep,name=V6AddrPorts,proto3" json:"V6AddrPorts,omitempty"` + Counter uint32 `protobuf:"varint,3,opt,name=counter,proto3" json:"counter,omitempty"` } func (m *NebulaMetaDetails) Reset() { *m = NebulaMetaDetails{} } @@ -220,30 +222,46 @@ func (m *NebulaMetaDetails) XXX_DiscardUnknown() { var xxx_messageInfo_NebulaMetaDetails proto.InternalMessageInfo -func (m *NebulaMetaDetails) GetVpnIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldVpnAddr() uint32 { if m != nil { - return m.VpnIp + return m.OldVpnAddr } return 0 } -func (m *NebulaMetaDetails) GetIp4AndPorts() []*Ip4AndPort { +func (m *NebulaMetaDetails) GetVpnAddr() *Addr { if m != nil { - return m.Ip4AndPorts + return m.VpnAddr } return nil } -func (m *NebulaMetaDetails) GetIp6AndPorts() []*Ip6AndPort { +// Deprecated: Do not use. +func (m *NebulaMetaDetails) GetOldRelayVpnAddrs() []uint32 { if m != nil { - return m.Ip6AndPorts + return m.OldRelayVpnAddrs } return nil } -func (m *NebulaMetaDetails) GetRelayVpnIp() []uint32 { +func (m *NebulaMetaDetails) GetRelayVpnAddrs() []*Addr { if m != nil { - return m.RelayVpnIp + return m.RelayVpnAddrs + } + return nil +} + +func (m *NebulaMetaDetails) GetV4AddrPorts() []*V4AddrPort { + if m != nil { + return m.V4AddrPorts + } + return nil +} + +func (m *NebulaMetaDetails) GetV6AddrPorts() []*V6AddrPort { + if m != nil { + return m.V6AddrPorts } return nil } @@ -255,23 +273,23 @@ func (m *NebulaMetaDetails) GetCounter() uint32 { return 0 } -type Ip4AndPort struct { - Ip uint32 `protobuf:"varint,1,opt,name=Ip,proto3" json:"Ip,omitempty"` - Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +type Addr struct { + Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` + Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` } -func (m *Ip4AndPort) Reset() { *m = Ip4AndPort{} } -func (m *Ip4AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip4AndPort) ProtoMessage() {} -func (*Ip4AndPort) Descriptor() ([]byte, []int) { +func (m *Addr) Reset() { *m = Addr{} } +func (m *Addr) String() string { return proto.CompactTextString(m) } +func (*Addr) ProtoMessage() {} +func (*Addr) Descriptor() ([]byte, []int) { return fileDescriptor_2d65afa7693df5ef, []int{2} } -func (m *Ip4AndPort) XXX_Unmarshal(b []byte) error { +func (m *Addr) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *Addr) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ip4AndPort.Marshal(b, m, deterministic) + return xxx_messageInfo_Addr.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -281,50 +299,102 @@ func (m *Ip4AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ip4AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip4AndPort.Merge(m, src) +func (m *Addr) XXX_Merge(src proto.Message) { + xxx_messageInfo_Addr.Merge(m, src) } -func (m *Ip4AndPort) XXX_Size() int { +func (m *Addr) XXX_Size() int { return m.Size() } -func (m *Ip4AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip4AndPort.DiscardUnknown(m) +func (m *Addr) XXX_DiscardUnknown() { + xxx_messageInfo_Addr.DiscardUnknown(m) } -var xxx_messageInfo_Ip4AndPort proto.InternalMessageInfo +var xxx_messageInfo_Addr proto.InternalMessageInfo -func (m *Ip4AndPort) GetIp() uint32 { +func (m *Addr) GetHi() uint64 { if m != nil { - return m.Ip + return m.Hi } return 0 } -func (m *Ip4AndPort) GetPort() uint32 { +func (m *Addr) GetLo() uint64 { + if m != nil { + return m.Lo + } + return 0 +} + +type V4AddrPort struct { + Addr uint32 `protobuf:"varint,1,opt,name=Addr,proto3" json:"Addr,omitempty"` + Port uint32 `protobuf:"varint,2,opt,name=Port,proto3" json:"Port,omitempty"` +} + +func (m *V4AddrPort) Reset() { *m = V4AddrPort{} } +func (m *V4AddrPort) String() string { return proto.CompactTextString(m) } +func (*V4AddrPort) ProtoMessage() {} +func (*V4AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{3} +} +func (m *V4AddrPort) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *V4AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_V4AddrPort.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *V4AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V4AddrPort.Merge(m, src) +} +func (m *V4AddrPort) XXX_Size() int { + return m.Size() +} +func (m *V4AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V4AddrPort.DiscardUnknown(m) +} + +var xxx_messageInfo_V4AddrPort proto.InternalMessageInfo + +func (m *V4AddrPort) GetAddr() uint32 { + if m != nil { + return m.Addr + } + return 0 +} + +func (m *V4AddrPort) GetPort() uint32 { if m != nil { return m.Port } return 0 } -type Ip6AndPort struct { +type V6AddrPort struct { Hi uint64 `protobuf:"varint,1,opt,name=Hi,proto3" json:"Hi,omitempty"` Lo uint64 `protobuf:"varint,2,opt,name=Lo,proto3" json:"Lo,omitempty"` Port uint32 `protobuf:"varint,3,opt,name=Port,proto3" json:"Port,omitempty"` } -func (m *Ip6AndPort) Reset() { *m = Ip6AndPort{} } -func (m *Ip6AndPort) String() string { return proto.CompactTextString(m) } -func (*Ip6AndPort) ProtoMessage() {} -func (*Ip6AndPort) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{3} +func (m *V6AddrPort) Reset() { *m = V6AddrPort{} } +func (m *V6AddrPort) String() string { return proto.CompactTextString(m) } +func (*V6AddrPort) ProtoMessage() {} +func (*V6AddrPort) Descriptor() ([]byte, []int) { + return fileDescriptor_2d65afa7693df5ef, []int{4} } -func (m *Ip6AndPort) XXX_Unmarshal(b []byte) error { +func (m *V6AddrPort) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) } -func (m *Ip6AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { +func (m *V6AddrPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { if deterministic { - return xxx_messageInfo_Ip6AndPort.Marshal(b, m, deterministic) + return xxx_messageInfo_V6AddrPort.Marshal(b, m, deterministic) } else { b = b[:cap(b)] n, err := m.MarshalToSizedBuffer(b) @@ -334,33 +404,33 @@ func (m *Ip6AndPort) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return b[:n], nil } } -func (m *Ip6AndPort) XXX_Merge(src proto.Message) { - xxx_messageInfo_Ip6AndPort.Merge(m, src) +func (m *V6AddrPort) XXX_Merge(src proto.Message) { + xxx_messageInfo_V6AddrPort.Merge(m, src) } -func (m *Ip6AndPort) XXX_Size() int { +func (m *V6AddrPort) XXX_Size() int { return m.Size() } -func (m *Ip6AndPort) XXX_DiscardUnknown() { - xxx_messageInfo_Ip6AndPort.DiscardUnknown(m) +func (m *V6AddrPort) XXX_DiscardUnknown() { + xxx_messageInfo_V6AddrPort.DiscardUnknown(m) } -var xxx_messageInfo_Ip6AndPort proto.InternalMessageInfo +var xxx_messageInfo_V6AddrPort proto.InternalMessageInfo -func (m *Ip6AndPort) GetHi() uint64 { +func (m *V6AddrPort) GetHi() uint64 { if m != nil { return m.Hi } return 0 } -func (m *Ip6AndPort) GetLo() uint64 { +func (m *V6AddrPort) GetLo() uint64 { if m != nil { return m.Lo } return 0 } -func (m *Ip6AndPort) GetPort() uint32 { +func (m *V6AddrPort) GetPort() uint32 { if m != nil { return m.Port } @@ -376,7 +446,7 @@ func (m *NebulaPing) Reset() { *m = NebulaPing{} } func (m *NebulaPing) String() string { return proto.CompactTextString(m) } func (*NebulaPing) ProtoMessage() {} func (*NebulaPing) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{4} + return fileDescriptor_2d65afa7693df5ef, []int{5} } func (m *NebulaPing) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -428,7 +498,7 @@ func (m *NebulaHandshake) Reset() { *m = NebulaHandshake{} } func (m *NebulaHandshake) String() string { return proto.CompactTextString(m) } func (*NebulaHandshake) ProtoMessage() {} func (*NebulaHandshake) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{5} + return fileDescriptor_2d65afa7693df5ef, []int{6} } func (m *NebulaHandshake) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -477,13 +547,14 @@ type NebulaHandshakeDetails struct { ResponderIndex uint32 `protobuf:"varint,3,opt,name=ResponderIndex,proto3" json:"ResponderIndex,omitempty"` Cookie uint64 `protobuf:"varint,4,opt,name=Cookie,proto3" json:"Cookie,omitempty"` Time uint64 `protobuf:"varint,5,opt,name=Time,proto3" json:"Time,omitempty"` + CertVersion uint32 `protobuf:"varint,8,opt,name=CertVersion,proto3" json:"CertVersion,omitempty"` } func (m *NebulaHandshakeDetails) Reset() { *m = NebulaHandshakeDetails{} } func (m *NebulaHandshakeDetails) String() string { return proto.CompactTextString(m) } func (*NebulaHandshakeDetails) ProtoMessage() {} func (*NebulaHandshakeDetails) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{6} + return fileDescriptor_2d65afa7693df5ef, []int{7} } func (m *NebulaHandshakeDetails) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -547,19 +618,28 @@ func (m *NebulaHandshakeDetails) GetTime() uint64 { return 0 } +func (m *NebulaHandshakeDetails) GetCertVersion() uint32 { + if m != nil { + return m.CertVersion + } + return 0 +} + type NebulaControl struct { Type NebulaControl_MessageType `protobuf:"varint,1,opt,name=Type,proto3,enum=nebula.NebulaControl_MessageType" json:"Type,omitempty"` InitiatorRelayIndex uint32 `protobuf:"varint,2,opt,name=InitiatorRelayIndex,proto3" json:"InitiatorRelayIndex,omitempty"` ResponderRelayIndex uint32 `protobuf:"varint,3,opt,name=ResponderRelayIndex,proto3" json:"ResponderRelayIndex,omitempty"` - RelayToIp uint32 `protobuf:"varint,4,opt,name=RelayToIp,proto3" json:"RelayToIp,omitempty"` - RelayFromIp uint32 `protobuf:"varint,5,opt,name=RelayFromIp,proto3" json:"RelayFromIp,omitempty"` + OldRelayToAddr uint32 `protobuf:"varint,4,opt,name=OldRelayToAddr,proto3" json:"OldRelayToAddr,omitempty"` // Deprecated: Do not use. + OldRelayFromAddr uint32 `protobuf:"varint,5,opt,name=OldRelayFromAddr,proto3" json:"OldRelayFromAddr,omitempty"` // Deprecated: Do not use. + RelayToAddr *Addr `protobuf:"bytes,6,opt,name=RelayToAddr,proto3" json:"RelayToAddr,omitempty"` + RelayFromAddr *Addr `protobuf:"bytes,7,opt,name=RelayFromAddr,proto3" json:"RelayFromAddr,omitempty"` } func (m *NebulaControl) Reset() { *m = NebulaControl{} } func (m *NebulaControl) String() string { return proto.CompactTextString(m) } func (*NebulaControl) ProtoMessage() {} func (*NebulaControl) Descriptor() ([]byte, []int) { - return fileDescriptor_2d65afa7693df5ef, []int{7} + return fileDescriptor_2d65afa7693df5ef, []int{8} } func (m *NebulaControl) XXX_Unmarshal(b []byte) error { return m.Unmarshal(b) @@ -609,28 +689,45 @@ func (m *NebulaControl) GetResponderRelayIndex() uint32 { return 0 } -func (m *NebulaControl) GetRelayToIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayToAddr() uint32 { if m != nil { - return m.RelayToIp + return m.OldRelayToAddr } return 0 } -func (m *NebulaControl) GetRelayFromIp() uint32 { +// Deprecated: Do not use. +func (m *NebulaControl) GetOldRelayFromAddr() uint32 { if m != nil { - return m.RelayFromIp + return m.OldRelayFromAddr } return 0 } +func (m *NebulaControl) GetRelayToAddr() *Addr { + if m != nil { + return m.RelayToAddr + } + return nil +} + +func (m *NebulaControl) GetRelayFromAddr() *Addr { + if m != nil { + return m.RelayFromAddr + } + return nil +} + func init() { proto.RegisterEnum("nebula.NebulaMeta_MessageType", NebulaMeta_MessageType_name, NebulaMeta_MessageType_value) proto.RegisterEnum("nebula.NebulaPing_MessageType", NebulaPing_MessageType_name, NebulaPing_MessageType_value) proto.RegisterEnum("nebula.NebulaControl_MessageType", NebulaControl_MessageType_name, NebulaControl_MessageType_value) proto.RegisterType((*NebulaMeta)(nil), "nebula.NebulaMeta") proto.RegisterType((*NebulaMetaDetails)(nil), "nebula.NebulaMetaDetails") - proto.RegisterType((*Ip4AndPort)(nil), "nebula.Ip4AndPort") - proto.RegisterType((*Ip6AndPort)(nil), "nebula.Ip6AndPort") + proto.RegisterType((*Addr)(nil), "nebula.Addr") + proto.RegisterType((*V4AddrPort)(nil), "nebula.V4AddrPort") + proto.RegisterType((*V6AddrPort)(nil), "nebula.V6AddrPort") proto.RegisterType((*NebulaPing)(nil), "nebula.NebulaPing") proto.RegisterType((*NebulaHandshake)(nil), "nebula.NebulaHandshake") proto.RegisterType((*NebulaHandshakeDetails)(nil), "nebula.NebulaHandshakeDetails") @@ -640,52 +737,57 @@ func init() { func init() { proto.RegisterFile("nebula.proto", fileDescriptor_2d65afa7693df5ef) } var fileDescriptor_2d65afa7693df5ef = []byte{ - // 707 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x7c, 0x54, 0x4d, 0x6f, 0xda, 0x4a, - 0x14, 0xc5, 0xc6, 0x7c, 0x5d, 0x02, 0xf1, 0xbb, 0x79, 0x8f, 0x07, 0x4f, 0xaf, 0x16, 0xf5, 0xa2, - 0x62, 0x45, 0x22, 0x92, 0x46, 0x5d, 0x36, 0xa5, 0xaa, 0x20, 0x4a, 0x22, 0x3a, 0x4a, 0x5b, 0xa9, - 0x9b, 0x6a, 0x62, 0xa6, 0xc1, 0x02, 0x3c, 0x8e, 0x3d, 0x54, 0xe1, 0x5f, 0xf4, 0xc7, 0xe4, 0x47, - 0x74, 0xd7, 0x2c, 0xbb, 0xac, 0x92, 0x65, 0x97, 0xfd, 0x03, 0xd5, 0x8c, 0xc1, 0x36, 0x84, 0x76, - 0x37, 0xe7, 0xde, 0x73, 0x66, 0xce, 0x9c, 0xb9, 0x36, 0x6c, 0x79, 0xec, 0x62, 0x36, 0xa1, 0x6d, - 0x3f, 0xe0, 0x82, 0x63, 0x3e, 0x42, 0xf6, 0x0f, 0x1d, 0xe0, 0x4c, 0x2d, 0x4f, 0x99, 0xa0, 0xd8, - 0x01, 0xe3, 0x7c, 0xee, 0xb3, 0xba, 0xd6, 0xd4, 0x5a, 0xd5, 0x8e, 0xd5, 0x5e, 0x68, 0x12, 0x46, - 0xfb, 0x94, 0x85, 0x21, 0xbd, 0x64, 0x92, 0x45, 0x14, 0x17, 0xf7, 0xa1, 0xf0, 0x92, 0x09, 0xea, - 0x4e, 0xc2, 0xba, 0xde, 0xd4, 0x5a, 0xe5, 0x4e, 0xe3, 0xa1, 0x6c, 0x41, 0x20, 0x4b, 0xa6, 0xfd, - 0x53, 0x83, 0x72, 0x6a, 0x2b, 0x2c, 0x82, 0x71, 0xc6, 0x3d, 0x66, 0x66, 0xb0, 0x02, 0xa5, 0x1e, - 0x0f, 0xc5, 0xeb, 0x19, 0x0b, 0xe6, 0xa6, 0x86, 0x08, 0xd5, 0x18, 0x12, 0xe6, 0x4f, 0xe6, 0xa6, - 0x8e, 0xff, 0x41, 0x4d, 0xd6, 0xde, 0xf8, 0x43, 0x2a, 0xd8, 0x19, 0x17, 0xee, 0x47, 0xd7, 0xa1, - 0xc2, 0xe5, 0x9e, 0x99, 0xc5, 0x06, 0xfc, 0x23, 0x7b, 0xa7, 0xfc, 0x13, 0x1b, 0xae, 0xb4, 0x8c, - 0x65, 0x6b, 0x30, 0xf3, 0x9c, 0xd1, 0x4a, 0x2b, 0x87, 0x55, 0x00, 0xd9, 0x7a, 0x37, 0xe2, 0x74, - 0xea, 0x9a, 0x79, 0xdc, 0x81, 0xed, 0x04, 0x47, 0xc7, 0x16, 0xa4, 0xb3, 0x01, 0x15, 0xa3, 0xee, - 0x88, 0x39, 0x63, 0xb3, 0x28, 0x9d, 0xc5, 0x30, 0xa2, 0x94, 0xf0, 0x11, 0x34, 0x36, 0x3b, 0x3b, - 0x72, 0xc6, 0x26, 0xd8, 0x5f, 0x35, 0xf8, 0xeb, 0x41, 0x28, 0xf8, 0x37, 0xe4, 0xde, 0xfa, 0x5e, - 0xdf, 0x57, 0xa9, 0x57, 0x48, 0x04, 0xf0, 0x00, 0xca, 0x7d, 0xff, 0xe0, 0xc8, 0x1b, 0x0e, 0x78, - 0x20, 0x64, 0xb4, 0xd9, 0x56, 0xb9, 0x83, 0xcb, 0x68, 0x93, 0x16, 0x49, 0xd3, 0x22, 0xd5, 0x61, - 0xac, 0x32, 0xd6, 0x55, 0x87, 0x29, 0x55, 0x4c, 0x43, 0x0b, 0x80, 0xb0, 0x09, 0x9d, 0x47, 0x36, - 0x72, 0xcd, 0x6c, 0xab, 0x42, 0x52, 0x15, 0xac, 0x43, 0xc1, 0xe1, 0x33, 0x4f, 0xb0, 0xa0, 0x9e, - 0x55, 0x1e, 0x97, 0xd0, 0xde, 0x03, 0x48, 0x8e, 0xc7, 0x2a, 0xe8, 0xf1, 0x35, 0xf4, 0xbe, 0x8f, - 0x08, 0x86, 0xac, 0xab, 0xb9, 0xa8, 0x10, 0xb5, 0xb6, 0x9f, 0x4b, 0xc5, 0x61, 0x4a, 0xd1, 0x73, - 0x95, 0xc2, 0x20, 0x7a, 0xcf, 0x95, 0xf8, 0x84, 0x2b, 0xbe, 0x41, 0xf4, 0x13, 0x1e, 0xef, 0x90, - 0x4d, 0xed, 0x70, 0xbd, 0x1c, 0xd9, 0x81, 0xeb, 0x5d, 0xfe, 0x79, 0x64, 0x25, 0x63, 0xc3, 0xc8, - 0x22, 0x18, 0xe7, 0xee, 0x94, 0x2d, 0xce, 0x51, 0x6b, 0xdb, 0x7e, 0x30, 0x90, 0x52, 0x6c, 0x66, - 0xb0, 0x04, 0xb9, 0xe8, 0x79, 0x35, 0xfb, 0x03, 0x6c, 0x47, 0xfb, 0xf6, 0xa8, 0x37, 0x0c, 0x47, - 0x74, 0xcc, 0xf0, 0x59, 0x32, 0xfd, 0x9a, 0x9a, 0xfe, 0x35, 0x07, 0x31, 0x73, 0xfd, 0x13, 0x90, - 0x26, 0x7a, 0x53, 0xea, 0x28, 0x13, 0x5b, 0x44, 0xad, 0xed, 0x1b, 0x0d, 0x6a, 0x9b, 0x75, 0x92, - 0xde, 0x65, 0x81, 0x50, 0xa7, 0x6c, 0x11, 0xb5, 0xc6, 0x27, 0x50, 0xed, 0x7b, 0xae, 0x70, 0xa9, - 0xe0, 0x41, 0xdf, 0x1b, 0xb2, 0xeb, 0x45, 0xd2, 0x6b, 0x55, 0xc9, 0x23, 0x2c, 0xf4, 0xb9, 0x37, - 0x64, 0x0b, 0x5e, 0x94, 0xe7, 0x5a, 0x15, 0x6b, 0x90, 0xef, 0x72, 0x3e, 0x76, 0x59, 0xdd, 0x50, - 0xc9, 0x2c, 0x50, 0x9c, 0x57, 0x2e, 0xc9, 0xeb, 0xd8, 0x28, 0xe6, 0xcd, 0xc2, 0xb1, 0x51, 0x2c, - 0x98, 0x45, 0xfb, 0x46, 0x87, 0x4a, 0x64, 0xbb, 0xcb, 0x3d, 0x11, 0xf0, 0x09, 0x3e, 0x5d, 0x79, - 0x95, 0xc7, 0xab, 0x99, 0x2c, 0x48, 0x1b, 0x1e, 0x66, 0x0f, 0x76, 0x62, 0xeb, 0x6a, 0xfe, 0xd2, - 0xb7, 0xda, 0xd4, 0x92, 0x8a, 0xf8, 0x12, 0x29, 0x45, 0x74, 0xbf, 0x4d, 0x2d, 0xfc, 0x1f, 0x4a, - 0x0a, 0x9d, 0xf3, 0xbe, 0xaf, 0xee, 0x59, 0x21, 0x49, 0x01, 0x9b, 0x50, 0x56, 0xe0, 0x55, 0xc0, - 0xa7, 0xea, 0x5b, 0x90, 0xfd, 0x74, 0xc9, 0xee, 0xfd, 0xee, 0xcf, 0x55, 0x03, 0xec, 0x06, 0x8c, - 0x0a, 0xa6, 0xd8, 0x84, 0x5d, 0xcd, 0x58, 0x28, 0x4c, 0x0d, 0xff, 0x85, 0x9d, 0x95, 0xba, 0xb4, - 0x14, 0x32, 0x53, 0x7f, 0xb1, 0xff, 0xe5, 0xce, 0xd2, 0x6e, 0xef, 0x2c, 0xed, 0xfb, 0x9d, 0xa5, - 0x7d, 0xbe, 0xb7, 0x32, 0xb7, 0xf7, 0x56, 0xe6, 0xdb, 0xbd, 0x95, 0x79, 0xdf, 0xb8, 0x74, 0xc5, - 0x68, 0x76, 0xd1, 0x76, 0xf8, 0x74, 0x37, 0x9c, 0x50, 0x67, 0x3c, 0xba, 0xda, 0x8d, 0x22, 0xbc, - 0xc8, 0xab, 0x1f, 0xf8, 0xfe, 0xaf, 0x00, 0x00, 0x00, 0xff, 0xff, 0x17, 0x56, 0x28, 0x74, 0xd0, - 0x05, 0x00, 0x00, + // 785 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x55, 0xcd, 0x6e, 0xeb, 0x44, + 0x14, 0x8e, 0x1d, 0x27, 0x4e, 0x4f, 0x7e, 0xae, 0x39, 0x15, 0xc1, 0x41, 0x22, 0x0a, 0x5e, 0x54, + 0x57, 0x2c, 0x72, 0x51, 0x5a, 0xae, 0x58, 0x72, 0x1b, 0x84, 0xd2, 0xaa, 0x3f, 0x61, 0x54, 0x8a, + 0xc4, 0x06, 0xb9, 0xf6, 0xd0, 0x58, 0x71, 0x3c, 0xa9, 0x3d, 0x41, 0xcd, 0x5b, 0xf0, 0x30, 0x3c, + 0x04, 0xec, 0xba, 0x42, 0x2c, 0x51, 0xbb, 0x64, 0xc9, 0x0b, 0xa0, 0x19, 0xff, 0x27, 0x86, 0xbb, + 0x9b, 0x73, 0xbe, 0xef, 0x3b, 0x73, 0xe6, 0xf3, 0x9c, 0x31, 0x74, 0x02, 0x7a, 0xb7, 0xf1, 0xed, + 0xf1, 0x3a, 0x64, 0x9c, 0x61, 0x33, 0x8e, 0xac, 0xbf, 0x55, 0x80, 0x2b, 0xb9, 0xbc, 0xa4, 0xdc, + 0xc6, 0x09, 0x68, 0x37, 0xdb, 0x35, 0x35, 0x95, 0x91, 0xf2, 0xba, 0x37, 0x19, 0x8e, 0x13, 0x4d, + 0xce, 0x18, 0x5f, 0xd2, 0x28, 0xb2, 0xef, 0xa9, 0x60, 0x11, 0xc9, 0xc5, 0x63, 0xd0, 0xbf, 0xa6, + 0xdc, 0xf6, 0xfc, 0xc8, 0x54, 0x47, 0xca, 0xeb, 0xf6, 0x64, 0xb0, 0x2f, 0x4b, 0x08, 0x24, 0x65, + 0x5a, 0xff, 0x28, 0xd0, 0x2e, 0x94, 0xc2, 0x16, 0x68, 0x57, 0x2c, 0xa0, 0x46, 0x0d, 0xbb, 0x70, + 0x30, 0x63, 0x11, 0xff, 0x76, 0x43, 0xc3, 0xad, 0xa1, 0x20, 0x42, 0x2f, 0x0b, 0x09, 0x5d, 0xfb, + 0x5b, 0x43, 0xc5, 0x8f, 0xa1, 0x2f, 0x72, 0xdf, 0xad, 0x5d, 0x9b, 0xd3, 0x2b, 0xc6, 0xbd, 0x9f, + 0x3c, 0xc7, 0xe6, 0x1e, 0x0b, 0x8c, 0x3a, 0x0e, 0xe0, 0x43, 0x81, 0x5d, 0xb2, 0x9f, 0xa9, 0x5b, + 0x82, 0xb4, 0x14, 0x9a, 0x6f, 0x02, 0x67, 0x51, 0x82, 0x1a, 0xd8, 0x03, 0x10, 0xd0, 0xf7, 0x0b, + 0x66, 0xaf, 0x3c, 0xa3, 0x89, 0x87, 0xf0, 0x2a, 0x8f, 0xe3, 0x6d, 0x75, 0xd1, 0xd9, 0xdc, 0xe6, + 0x8b, 0xe9, 0x82, 0x3a, 0x4b, 0xa3, 0x25, 0x3a, 0xcb, 0xc2, 0x98, 0x72, 0x80, 0x9f, 0xc0, 0xa0, + 0xba, 0xb3, 0x77, 0xce, 0xd2, 0x00, 0xeb, 0x77, 0x15, 0x3e, 0xd8, 0x33, 0x05, 0x2d, 0x80, 0x6b, + 0xdf, 0xbd, 0x5d, 0x07, 0xef, 0x5c, 0x37, 0x94, 0xd6, 0x77, 0x4f, 0x55, 0x53, 0x21, 0x85, 0x2c, + 0x1e, 0x81, 0x9e, 0x12, 0x9a, 0xd2, 0xe4, 0x4e, 0x6a, 0xb2, 0xc8, 0x91, 0x14, 0xc4, 0x31, 0x18, + 0xd7, 0xbe, 0x4b, 0xa8, 0x6f, 0x6f, 0x93, 0x54, 0x64, 0x36, 0x46, 0xf5, 0xa4, 0xe2, 0x1e, 0x86, + 0x13, 0xe8, 0x96, 0xc9, 0xfa, 0xa8, 0xbe, 0x57, 0xbd, 0x4c, 0xc1, 0x13, 0x68, 0xdf, 0x9e, 0x88, + 0xe5, 0x9c, 0x85, 0x5c, 0x7c, 0x74, 0xa1, 0xc0, 0x54, 0x91, 0x43, 0xa4, 0x48, 0x93, 0xaa, 0xb7, + 0xb9, 0x4a, 0xdb, 0x51, 0xbd, 0x2d, 0xa8, 0x72, 0x1a, 0x9a, 0xa0, 0x3b, 0x6c, 0x13, 0x70, 0x1a, + 0x9a, 0x75, 0x61, 0x0c, 0x49, 0x43, 0xeb, 0x08, 0x34, 0x79, 0xe2, 0x1e, 0xa8, 0x33, 0x4f, 0xba, + 0xa6, 0x11, 0x75, 0xe6, 0x89, 0xf8, 0x82, 0xc9, 0x9b, 0xa8, 0x11, 0xf5, 0x82, 0x59, 0x27, 0x00, + 0x79, 0x1b, 0x88, 0xb1, 0x2a, 0x76, 0x99, 0xc4, 0x15, 0x10, 0x34, 0x81, 0x49, 0x4d, 0x97, 0xc8, + 0xb5, 0xf5, 0x15, 0x40, 0xde, 0xc6, 0xfb, 0xf6, 0xc8, 0x2a, 0xd4, 0x0b, 0x15, 0x1e, 0xd3, 0xc1, + 0x9a, 0x7b, 0xc1, 0xfd, 0xff, 0x0f, 0x96, 0x60, 0x54, 0x0c, 0x16, 0x82, 0x76, 0xe3, 0xad, 0x68, + 0xb2, 0x8f, 0x5c, 0x5b, 0xd6, 0xde, 0xd8, 0x08, 0xb1, 0x51, 0xc3, 0x03, 0x68, 0xc4, 0x97, 0x50, + 0xb1, 0x7e, 0x84, 0x57, 0x71, 0xdd, 0x99, 0x1d, 0xb8, 0xd1, 0xc2, 0x5e, 0x52, 0xfc, 0x32, 0x9f, + 0x51, 0x45, 0x5e, 0x9f, 0x9d, 0x0e, 0x32, 0xe6, 0xee, 0xa0, 0x8a, 0x26, 0x66, 0x2b, 0xdb, 0x91, + 0x4d, 0x74, 0x88, 0x5c, 0x5b, 0x7f, 0x28, 0xd0, 0xaf, 0xd6, 0x09, 0xfa, 0x94, 0x86, 0x5c, 0xee, + 0xd2, 0x21, 0x72, 0x8d, 0x47, 0xd0, 0x3b, 0x0b, 0x3c, 0xee, 0xd9, 0x9c, 0x85, 0x67, 0x81, 0x4b, + 0x1f, 0x13, 0xa7, 0x77, 0xb2, 0x82, 0x47, 0x68, 0xb4, 0x66, 0x81, 0x4b, 0x13, 0x5e, 0xec, 0xe7, + 0x4e, 0x16, 0xfb, 0xd0, 0x9c, 0x32, 0xb6, 0xf4, 0xa8, 0xa9, 0x49, 0x67, 0x92, 0x28, 0xf3, 0xab, + 0x91, 0xfb, 0x85, 0x23, 0x68, 0x8b, 0x1e, 0x6e, 0x69, 0x18, 0x79, 0x2c, 0x30, 0x5b, 0xb2, 0x60, + 0x31, 0x75, 0xae, 0xb5, 0x9a, 0x86, 0x7e, 0xae, 0xb5, 0x74, 0xa3, 0x65, 0xfd, 0x5a, 0x87, 0x6e, + 0x7c, 0xb0, 0x29, 0x0b, 0x78, 0xc8, 0x7c, 0xfc, 0xa2, 0xf4, 0xdd, 0x3e, 0x2d, 0xbb, 0x96, 0x90, + 0x2a, 0x3e, 0xdd, 0xe7, 0x70, 0x98, 0x1d, 0x4e, 0x0e, 0x4f, 0xf1, 0xdc, 0x55, 0x90, 0x50, 0x64, + 0xc7, 0x2c, 0x28, 0x62, 0x07, 0xaa, 0x20, 0xfc, 0x0c, 0x7a, 0xe9, 0x38, 0xdf, 0x30, 0x79, 0xa9, + 0xb5, 0xec, 0xe9, 0xd8, 0x41, 0x8a, 0xcf, 0xc2, 0x37, 0x21, 0x5b, 0x49, 0x76, 0x23, 0x63, 0xef, + 0x61, 0x38, 0x86, 0x76, 0xb1, 0x70, 0xd5, 0x93, 0x53, 0x24, 0x64, 0xcf, 0x48, 0x56, 0x5c, 0xaf, + 0x50, 0x94, 0x29, 0xd6, 0xec, 0xbf, 0xfe, 0x00, 0x7d, 0xc0, 0x69, 0x48, 0x6d, 0x4e, 0x25, 0x9f, + 0xd0, 0x87, 0x0d, 0x8d, 0xb8, 0xa1, 0xe0, 0x47, 0x70, 0x58, 0xca, 0x0b, 0x4b, 0x22, 0x6a, 0xa8, + 0xa7, 0xc7, 0xbf, 0x3d, 0x0f, 0x95, 0xa7, 0xe7, 0xa1, 0xf2, 0xd7, 0xf3, 0x50, 0xf9, 0xe5, 0x65, + 0x58, 0x7b, 0x7a, 0x19, 0xd6, 0xfe, 0x7c, 0x19, 0xd6, 0x7e, 0x18, 0xdc, 0x7b, 0x7c, 0xb1, 0xb9, + 0x1b, 0x3b, 0x6c, 0xf5, 0x26, 0xf2, 0x6d, 0x67, 0xb9, 0x78, 0x78, 0x13, 0xb7, 0x74, 0xd7, 0x94, + 0x3f, 0xc2, 0xe3, 0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xea, 0x6f, 0xbc, 0x50, 0x18, 0x07, 0x00, + 0x00, } func (m *NebulaMeta) Marshal() (dAtA []byte, err error) { @@ -748,28 +850,54 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if len(m.RelayVpnIp) > 0 { - dAtA3 := make([]byte, len(m.RelayVpnIp)*10) - var j2 int - for _, num := range m.RelayVpnIp { + if len(m.RelayVpnAddrs) > 0 { + for iNdEx := len(m.RelayVpnAddrs) - 1; iNdEx >= 0; iNdEx-- { + { + size, err := m.RelayVpnAddrs[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x3a + } + } + if m.VpnAddr != nil { + { + size, err := m.VpnAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if len(m.OldRelayVpnAddrs) > 0 { + dAtA4 := make([]byte, len(m.OldRelayVpnAddrs)*10) + var j3 int + for _, num := range m.OldRelayVpnAddrs { for num >= 1<<7 { - dAtA3[j2] = uint8(uint64(num)&0x7f | 0x80) + dAtA4[j3] = uint8(uint64(num)&0x7f | 0x80) num >>= 7 - j2++ + j3++ } - dAtA3[j2] = uint8(num) - j2++ + dAtA4[j3] = uint8(num) + j3++ } - i -= j2 - copy(dAtA[i:], dAtA3[:j2]) - i = encodeVarintNebula(dAtA, i, uint64(j2)) + i -= j3 + copy(dAtA[i:], dAtA4[:j3]) + i = encodeVarintNebula(dAtA, i, uint64(j3)) i-- dAtA[i] = 0x2a } - if len(m.Ip6AndPorts) > 0 { - for iNdEx := len(m.Ip6AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V6AddrPorts) > 0 { + for iNdEx := len(m.V6AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip6AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V6AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -785,10 +913,10 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x18 } - if len(m.Ip4AndPorts) > 0 { - for iNdEx := len(m.Ip4AndPorts) - 1; iNdEx >= 0; iNdEx-- { + if len(m.V4AddrPorts) > 0 { + for iNdEx := len(m.V4AddrPorts) - 1; iNdEx >= 0; iNdEx-- { { - size, err := m.Ip4AndPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) + size, err := m.V4AddrPorts[iNdEx].MarshalToSizedBuffer(dAtA[:i]) if err != nil { return 0, err } @@ -799,15 +927,48 @@ func (m *NebulaMetaDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) { dAtA[i] = 0x12 } } - if m.VpnIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldVpnAddr)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + +func (m *Addr) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Addr) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Addr) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Lo != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Lo)) + i-- + dAtA[i] = 0x10 + } + if m.Hi != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Hi)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { +func (m *V4AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -817,12 +978,12 @@ func (m *Ip4AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip4AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *V4AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *V4AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -832,15 +993,15 @@ func (m *Ip4AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x10 } - if m.Ip != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.Ip)) + if m.Addr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.Addr)) i-- dAtA[i] = 0x8 } return len(dAtA) - i, nil } -func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { +func (m *V6AddrPort) Marshal() (dAtA []byte, err error) { size := m.Size() dAtA = make([]byte, size) n, err := m.MarshalToSizedBuffer(dAtA[:size]) @@ -850,12 +1011,12 @@ func (m *Ip6AndPort) Marshal() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *Ip6AndPort) MarshalTo(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalTo(dAtA []byte) (int, error) { size := m.Size() return m.MarshalToSizedBuffer(dAtA[:size]) } -func (m *Ip6AndPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { +func (m *V6AddrPort) MarshalToSizedBuffer(dAtA []byte) (int, error) { i := len(dAtA) _ = i var l int @@ -973,6 +1134,11 @@ func (m *NebulaHandshakeDetails) MarshalToSizedBuffer(dAtA []byte) (int, error) _ = i var l int _ = l + if m.CertVersion != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.CertVersion)) + i-- + dAtA[i] = 0x40 + } if m.Time != 0 { i = encodeVarintNebula(dAtA, i, uint64(m.Time)) i-- @@ -1023,13 +1189,37 @@ func (m *NebulaControl) MarshalToSizedBuffer(dAtA []byte) (int, error) { _ = i var l int _ = l - if m.RelayFromIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayFromIp)) + if m.RelayFromAddr != nil { + { + size, err := m.RelayFromAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x3a + } + if m.RelayToAddr != nil { + { + size, err := m.RelayToAddr.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintNebula(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x32 + } + if m.OldRelayFromAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayFromAddr)) i-- dAtA[i] = 0x28 } - if m.RelayToIp != 0 { - i = encodeVarintNebula(dAtA, i, uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + i = encodeVarintNebula(dAtA, i, uint64(m.OldRelayToAddr)) i-- dAtA[i] = 0x20 } @@ -1084,11 +1274,11 @@ func (m *NebulaMetaDetails) Size() (n int) { } var l int _ = l - if m.VpnIp != 0 { - n += 1 + sovNebula(uint64(m.VpnIp)) + if m.OldVpnAddr != 0 { + n += 1 + sovNebula(uint64(m.OldVpnAddr)) } - if len(m.Ip4AndPorts) > 0 { - for _, e := range m.Ip4AndPorts { + if len(m.V4AddrPorts) > 0 { + for _, e := range m.V4AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } @@ -1096,30 +1286,55 @@ func (m *NebulaMetaDetails) Size() (n int) { if m.Counter != 0 { n += 1 + sovNebula(uint64(m.Counter)) } - if len(m.Ip6AndPorts) > 0 { - for _, e := range m.Ip6AndPorts { + if len(m.V6AddrPorts) > 0 { + for _, e := range m.V6AddrPorts { l = e.Size() n += 1 + l + sovNebula(uint64(l)) } } - if len(m.RelayVpnIp) > 0 { + if len(m.OldRelayVpnAddrs) > 0 { l = 0 - for _, e := range m.RelayVpnIp { + for _, e := range m.OldRelayVpnAddrs { l += sovNebula(uint64(e)) } n += 1 + sovNebula(uint64(l)) + l } + if m.VpnAddr != nil { + l = m.VpnAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if len(m.RelayVpnAddrs) > 0 { + for _, e := range m.RelayVpnAddrs { + l = e.Size() + n += 1 + l + sovNebula(uint64(l)) + } + } return n } -func (m *Ip4AndPort) Size() (n int) { +func (m *Addr) Size() (n int) { if m == nil { return 0 } var l int _ = l - if m.Ip != 0 { - n += 1 + sovNebula(uint64(m.Ip)) + if m.Hi != 0 { + n += 1 + sovNebula(uint64(m.Hi)) + } + if m.Lo != 0 { + n += 1 + sovNebula(uint64(m.Lo)) + } + return n +} + +func (m *V4AddrPort) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Addr != 0 { + n += 1 + sovNebula(uint64(m.Addr)) } if m.Port != 0 { n += 1 + sovNebula(uint64(m.Port)) @@ -1127,7 +1342,7 @@ func (m *Ip4AndPort) Size() (n int) { return n } -func (m *Ip6AndPort) Size() (n int) { +func (m *V6AddrPort) Size() (n int) { if m == nil { return 0 } @@ -1199,6 +1414,9 @@ func (m *NebulaHandshakeDetails) Size() (n int) { if m.Time != 0 { n += 1 + sovNebula(uint64(m.Time)) } + if m.CertVersion != 0 { + n += 1 + sovNebula(uint64(m.CertVersion)) + } return n } @@ -1217,11 +1435,19 @@ func (m *NebulaControl) Size() (n int) { if m.ResponderRelayIndex != 0 { n += 1 + sovNebula(uint64(m.ResponderRelayIndex)) } - if m.RelayToIp != 0 { - n += 1 + sovNebula(uint64(m.RelayToIp)) + if m.OldRelayToAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayToAddr)) + } + if m.OldRelayFromAddr != 0 { + n += 1 + sovNebula(uint64(m.OldRelayFromAddr)) } - if m.RelayFromIp != 0 { - n += 1 + sovNebula(uint64(m.RelayFromIp)) + if m.RelayToAddr != nil { + l = m.RelayToAddr.Size() + n += 1 + l + sovNebula(uint64(l)) + } + if m.RelayFromAddr != nil { + l = m.RelayFromAddr.Size() + n += 1 + l + sovNebula(uint64(l)) } return n } @@ -1368,9 +1594,9 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field VpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldVpnAddr", wireType) } - m.VpnIp = 0 + m.OldVpnAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1380,14 +1606,14 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.VpnIp |= uint32(b&0x7F) << shift + m.OldVpnAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 2: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip4AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V4AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1414,8 +1640,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip4AndPorts = append(m.Ip4AndPorts, &Ip4AndPort{}) - if err := m.Ip4AndPorts[len(m.Ip4AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V4AddrPorts = append(m.V4AddrPorts, &V4AddrPort{}) + if err := m.V4AddrPorts[len(m.V4AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1440,7 +1666,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } case 4: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip6AndPorts", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field V6AddrPorts", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -1467,8 +1693,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Ip6AndPorts = append(m.Ip6AndPorts, &Ip6AndPort{}) - if err := m.Ip6AndPorts[len(m.Ip6AndPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + m.V6AddrPorts = append(m.V6AddrPorts, &V6AddrPort{}) + if err := m.V6AddrPorts[len(m.V6AddrPorts)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -1489,7 +1715,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } else if wireType == 2 { var packedLen int for shift := uint(0); ; shift += 7 { @@ -1524,8 +1750,8 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } } elementCount = count - if elementCount != 0 && len(m.RelayVpnIp) == 0 { - m.RelayVpnIp = make([]uint32, 0, elementCount) + if elementCount != 0 && len(m.OldRelayVpnAddrs) == 0 { + m.OldRelayVpnAddrs = make([]uint32, 0, elementCount) } for iNdEx < postIndex { var v uint32 @@ -1543,11 +1769,81 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { break } } - m.RelayVpnIp = append(m.RelayVpnIp, v) + m.OldRelayVpnAddrs = append(m.OldRelayVpnAddrs, v) } } else { - return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayVpnAddrs", wireType) + } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field VpnAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.VpnAddr == nil { + m.VpnAddr = &Addr{} + } + if err := m.VpnAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayVpnAddrs", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.RelayVpnAddrs = append(m.RelayVpnAddrs, &Addr{}) + if err := m.RelayVpnAddrs[len(m.RelayVpnAddrs)-1].Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -1569,7 +1865,7 @@ func (m *NebulaMetaDetails) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { +func (m *Addr) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1592,17 +1888,17 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip4AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: Addr: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip4AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: Addr: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Ip", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Hi", wireType) } - m.Ip = 0 + m.Hi = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -1612,7 +1908,95 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.Ip |= uint32(b&0x7F) << shift + m.Hi |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Lo", wireType) + } + m.Lo = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Lo |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipNebula(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthNebula + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *V4AddrPort) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: V4AddrPort: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: V4AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Addr", wireType) + } + m.Addr = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Addr |= uint32(b&0x7F) << shift if b < 0x80 { break } @@ -1657,7 +2041,7 @@ func (m *Ip4AndPort) Unmarshal(dAtA []byte) error { } return nil } -func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { +func (m *V6AddrPort) Unmarshal(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -1680,10 +2064,10 @@ func (m *Ip6AndPort) Unmarshal(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: Ip6AndPort: wiretype end group for non-group") + return fmt.Errorf("proto: V6AddrPort: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: Ip6AndPort: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: V6AddrPort: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -2111,6 +2495,25 @@ func (m *NebulaHandshakeDetails) Unmarshal(dAtA []byte) error { break } } + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field CertVersion", wireType) + } + m.CertVersion = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.CertVersion |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) @@ -2220,9 +2623,9 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } case 4: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayToIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayToAddr", wireType) } - m.RelayToIp = 0 + m.OldRelayToAddr = 0 for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2232,16 +2635,71 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayToIp |= uint32(b&0x7F) << shift + m.OldRelayToAddr |= uint32(b&0x7F) << shift if b < 0x80 { break } } case 5: if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field RelayFromIp", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field OldRelayFromAddr", wireType) + } + m.OldRelayFromAddr = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.OldRelayFromAddr |= uint32(b&0x7F) << shift + if b < 0x80 { + break + } + } + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayToAddr", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowNebula + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayToAddr == nil { + m.RelayToAddr = &Addr{} + } + if err := m.RelayToAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 7: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field RelayFromAddr", wireType) } - m.RelayFromIp = 0 + var msglen int for shift := uint(0); ; shift += 7 { if shift >= 64 { return ErrIntOverflowNebula @@ -2251,11 +2709,28 @@ func (m *NebulaControl) Unmarshal(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - m.RelayFromIp |= uint32(b&0x7F) << shift + msglen |= int(b&0x7F) << shift if b < 0x80 { break } } + if msglen < 0 { + return ErrInvalidLengthNebula + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthNebula + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.RelayFromAddr == nil { + m.RelayFromAddr = &Addr{} + } + if err := m.RelayFromAddr.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := skipNebula(dAtA[iNdEx:]) diff --git a/nebula.proto b/nebula.proto index 88e33b7e9..ea1023348 100644 --- a/nebula.proto +++ b/nebula.proto @@ -23,19 +23,28 @@ message NebulaMeta { } message NebulaMetaDetails { - uint32 VpnIp = 1; - repeated Ip4AndPort Ip4AndPorts = 2; - repeated Ip6AndPort Ip6AndPorts = 4; - repeated uint32 RelayVpnIp = 5; + uint32 OldVpnAddr = 1 [deprecated = true]; + Addr VpnAddr = 6; + + repeated uint32 OldRelayVpnAddrs = 5 [deprecated = true]; + repeated Addr RelayVpnAddrs = 7; + + repeated V4AddrPort V4AddrPorts = 2; + repeated V6AddrPort V6AddrPorts = 4; uint32 counter = 3; } -message Ip4AndPort { - uint32 Ip = 1; +message Addr { + uint64 Hi = 1; + uint64 Lo = 2; +} + +message V4AddrPort { + uint32 Addr = 1; uint32 Port = 2; } -message Ip6AndPort { +message V6AddrPort { uint64 Hi = 1; uint64 Lo = 2; uint32 Port = 3; @@ -62,6 +71,7 @@ message NebulaHandshakeDetails { uint32 ResponderIndex = 3; uint64 Cookie = 4; uint64 Time = 5; + uint32 CertVersion = 8; // reserved for WIP multiport reserved 6, 7; } @@ -76,6 +86,10 @@ message NebulaControl { uint32 InitiatorRelayIndex = 2; uint32 ResponderRelayIndex = 3; - uint32 RelayToIp = 4; - uint32 RelayFromIp = 5; + + uint32 OldRelayToAddr = 4 [deprecated = true]; + uint32 OldRelayFromAddr = 5 [deprecated = true]; + + Addr RelayToAddr = 6; + Addr RelayFromAddr = 7; } diff --git a/outside.go b/outside.go index 6a71fe77f..d4f97b264 100644 --- a/outside.go +++ b/outside.go @@ -7,12 +7,12 @@ import ( "net/netip" "time" - "github.com/flynn/noise" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv6" + "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" - "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" ) @@ -20,24 +20,7 @@ const ( minFwPacketLen = 4 ) -// TODO: IPV6-WORK this can likely be removed now -func readOutsidePackets(f *Interface) udp.EncReader { - return func( - addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh udp.LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, - ) { - f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache) - } -} - -func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) { +func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf *LightHouseHandler, nb []byte, q int, localCache firewall.ConntrackCache) { err := h.Parse(packet) if err != nil { // TODO: best if we return this and let caller log @@ -51,7 +34,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] //l.Error("in packet ", header, packet[HeaderLen:]) if ip.IsValid() { - if f.myVpnNet.Contains(ip.Addr()) { + _, found := f.myVpnNetworksTable.Lookup(ip.Addr()) + if found { if f.l.Level >= logrus.DebugLevel { f.l.WithField("udpAddr", ip).Debug("Refusing to process double encrypted packet") } @@ -108,7 +92,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] if !ok { // The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing // its internal mapping. This should never happen. - hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") + hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnAddrs": hostinfo.vpnAddrs, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index") return } @@ -120,9 +104,9 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return case ForwardingType: // Find the target HostInfo relay object - targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp) + targetHI, targetRelay, err := f.hostMap.QueryVpnAddrsRelayFor(hostinfo.vpnAddrs, relay.PeerAddr) if err != nil { - hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip") + hostinfo.logger(f.l).WithField("relayTo", relay.PeerAddr).WithError(err).WithField("hostinfo.vpnAddrs", hostinfo.vpnAddrs).Info("Failed to find target host info by ip") return } @@ -138,7 +122,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal") } } else { - hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") + hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerAddr, "relayFrom": hostinfo.vpnAddrs[0], "targetRelayState": targetRelay.State}).Info("Unexpected target relay state") return } } @@ -161,7 +145,7 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] return } - lhf(ip, hostinfo.vpnIp, d) + lhf.HandleRequest(ip, hostinfo.vpnAddrs, d, f) // Fallthrough to the bottom to record incoming traffic @@ -228,14 +212,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] Error("Failed to decrypt Control packet") return } - m := &NebulaControl{} - err = m.Unmarshal(d) - if err != nil { - hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message") - break - } - f.relayManager.HandleControlMsg(hostinfo, m, f) + f.relayManager.HandleControlMsg(hostinfo, d, f) default: f.messageMetrics.Rx(h.Type, h.Subtype, 1) @@ -252,8 +230,8 @@ func (f *Interface) readOutsidePackets(ip netip.AddrPort, via *ViaSender, out [] func (f *Interface) closeTunnel(hostInfo *HostInfo) { final := f.hostMap.DeleteHostInfo(hostInfo) if final { - // We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage - f.lightHouse.DeleteVpnIp(hostInfo.vpnIp) + // We no longer have any tunnels with this vpn addr, clear learned lighthouse state to lower memory usage + f.lightHouse.DeleteVpnAddrs(hostInfo.vpnAddrs) } } @@ -262,25 +240,26 @@ func (f *Interface) sendCloseTunnel(h *HostInfo) { f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu)) } -func (f *Interface) handleHostRoaming(hostinfo *HostInfo, ip netip.AddrPort) { - if ip.IsValid() && hostinfo.remote != ip { - if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, ip.Addr()) { - hostinfo.logger(f.l).WithField("newAddr", ip).Debug("lighthouse.remote_allow_list denied roaming") +func (f *Interface) handleHostRoaming(hostinfo *HostInfo, vpnAddr netip.AddrPort) { + if vpnAddr.IsValid() && hostinfo.remote != vpnAddr { + //TODO: this is weird now that we can have multiple vpn addrs + if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnAddrs[0], vpnAddr.Addr()) { + hostinfo.logger(f.l).WithField("newAddr", vpnAddr).Debug("lighthouse.remote_allow_list denied roaming") return } - if !hostinfo.lastRoam.IsZero() && ip == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { + if !hostinfo.lastRoam.IsZero() && vpnAddr == hostinfo.lastRoamRemote && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second { if f.l.Level >= logrus.DebugLevel { - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds) } return } - hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", ip). + hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", vpnAddr). Info("Host roamed to new udp ip/port.") hostinfo.lastRoam = time.Now() hostinfo.lastRoamRemote = hostinfo.remote - hostinfo.SetRemote(ip) + hostinfo.SetRemote(vpnAddr) } } @@ -302,14 +281,114 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { - // Do we at least have an ipv4 header worth of data? - if len(data) < ipv4.HeaderLen { - return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen) + if len(data) < 1 { + return errors.New("packet too short") + } + + version := int((data[0] >> 4) & 0x0f) + switch version { + case ipv4.Version: + return parseV4(data, incoming, fp) + case ipv6.Version: + return parseV6(data, incoming, fp) } + return fmt.Errorf("packet is an unknown ip version: %v", version) +} - // Is it an ipv4 packet? - if int((data[0]>>4)&0x0f) != 4 { - return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f)) +func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { + dataLen := len(data) + if dataLen < ipv6.HeaderLen { + return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen) + } + + if incoming { + fp.RemoteAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[24:40]) + } else { + fp.LocalAddr, _ = netip.AddrFromSlice(data[8:24]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40]) + } + + //TODO: whats a reasonable number of extension headers to attempt to parse? + //https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html + protoAt := 6 + offset := 40 + for i := 0; i < 24; i++ { + if dataLen < offset { + break + } + + proto := layers.IPProtocol(data[protoAt]) + //fmt.Println(proto, protoAt) + switch proto { + case layers.IPProtocolICMPv6: + //TODO: we need a new protocol in config language "icmpv6" + fp.Protocol = uint8(proto) + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = false + return nil + + case layers.IPProtocolTCP: + if dataLen < offset+4 { + return fmt.Errorf("ipv6 packet was too small") + } + fp.Protocol = uint8(proto) + if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } + fp.Fragment = false + return nil + + case layers.IPProtocolUDP: + if dataLen < offset+4 { + return fmt.Errorf("ipv6 packet was too small") + } + fp.Protocol = uint8(proto) + if incoming { + fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } else { + fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) + fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + } + fp.Fragment = false + return nil + + case layers.IPProtocolIPv6Fragment: + //TODO: can we determine the protocol? + fp.RemotePort = 0 + fp.LocalPort = 0 + fp.Fragment = true + return nil + + default: + if dataLen < offset+1 { + break + } + + next := int(data[offset+1]) * 8 + if next == 0 { + // each extension is at least 8 bytes + next = 8 + } + + protoAt = offset + offset = offset + next + } + } + + return fmt.Errorf("could not find payload in ipv6 packet") +} + +func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { + // Do we at least have an ipv4 header worth of data? + if len(data) < ipv4.HeaderLen { + return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen) } // Adjust our start position based on the advertised ip header length @@ -317,7 +396,7 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { // Well formed ip header length? if ihl < ipv4.HeaderLen { - return fmt.Errorf("packet had an invalid header length: %v", ihl) + return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl) } // Check if this is the second or further fragment of a fragmented packet. @@ -333,14 +412,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { minLen += minFwPacketLen } if len(data) < minLen { - return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl) + return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl) } // Firewall packets are locally oriented if incoming { - //TODO: IPV6-WORK - fp.RemoteIP, _ = netip.AddrFromSlice(data[12:16]) - fp.LocalIP, _ = netip.AddrFromSlice(data[16:20]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -349,9 +427,8 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4]) } } else { - //TODO: IPV6-WORK - fp.LocalIP, _ = netip.AddrFromSlice(data[12:16]) - fp.RemoteIP, _ = netip.AddrFromSlice(data[16:20]) + fp.LocalAddr, _ = netip.AddrFromSlice(data[12:16]) + fp.RemoteAddr, _ = netip.AddrFromSlice(data[16:20]) if fp.Fragment || fp.Protocol == firewall.ProtoICMP { fp.RemotePort = 0 fp.LocalPort = 0 @@ -492,27 +569,3 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N f.outside.WriteTo(msg, endpoint) } */ - -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) { - pk := h.PeerStatic() - - if pk == nil { - return nil, errors.New("no peer static key was present") - } - - if rawCertBytes == nil { - return nil, errors.New("provided payload was empty") - } - - c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk) - if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %w", err) - } - - cc, err := caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return nil, fmt.Errorf("certificate validation failed: %w", err) - } - - return cc, nil -} diff --git a/outside_test.go b/outside_test.go index f9d4bfa48..cbe622345 100644 --- a/outside_test.go +++ b/outside_test.go @@ -5,6 +5,9 @@ import ( "net/netip" "testing" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" "golang.org/x/net/ipv4" @@ -13,9 +16,15 @@ import ( func Test_newPacket(t *testing.T) { p := &firewall.Packet{} - // length fail - err := newPacket([]byte{0, 1}, true, p) - assert.EqualError(t, err, "packet is less than 20 bytes") + // length fails + err := newPacket([]byte{}, true, p) + assert.EqualError(t, err, "packet too short") + + err = newPacket([]byte{0x40}, true, p) + assert.EqualError(t, err, "ipv4 packet is less than 20 bytes") + + err = newPacket([]byte{0x60}, true, p) + assert.EqualError(t, err, "ipv6 packet is less than 20 bytes") // length fail with ip options h := ipv4.Header{ @@ -29,15 +38,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - assert.EqualError(t, err, "packet is less than 28 bytes, ip header len: 24") + assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24") // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet is not ipv4, type: 0") + assert.EqualError(t, err, "packet is an unknown ip version: 0") // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet had an invalid header length: 8") + assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8") // account for variable ip header length - incoming h = ipv4.Header{ @@ -55,8 +64,8 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1")) assert.Equal(t, p.RemotePort, uint16(3)) assert.Equal(t, p.LocalPort, uint16(4)) @@ -76,8 +85,60 @@ func Test_newPacket(t *testing.T) { assert.Nil(t, err) assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalIP, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemoteIP, netip.MustParseAddr("10.0.0.2")) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1")) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2")) assert.Equal(t, p.RemotePort, uint16(6)) assert.Equal(t, p.LocalPort, uint16(5)) } + +func Test_newPacket_v6(t *testing.T) { + p := &firewall.Packet{} + + ip := layers.IPv6{ + Version: 6, + NextHeader: firewall.ProtoUDP, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + err := udp.SetNetworkLayerForChecksum(&ip) + if err != nil { + panic(err) + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) + if err != nil { + panic(err) + } + b := buffer.Bytes() + + //test incoming + err = newPacket(b, true, p) + + assert.Nil(t, err) + assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2")) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1")) + assert.Equal(t, p.RemotePort, uint16(36123)) + assert.Equal(t, p.LocalPort, uint16(22)) + + //test outgoing + err = newPacket(b, false, p) + + assert.Nil(t, err) + assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) + assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2")) + assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1")) + assert.Equal(t, p.LocalPort, uint16(36123)) + assert.Equal(t, p.RemotePort, uint16(22)) +} diff --git a/overlay/device.go b/overlay/device.go index 50ad6ad5b..da8cbe92b 100644 --- a/overlay/device.go +++ b/overlay/device.go @@ -8,7 +8,7 @@ import ( type Device interface { io.ReadWriteCloser Activate() error - Cidr() netip.Prefix + Networks() []netip.Prefix Name() string RouteFor(netip.Addr) netip.Addr NewMultiQueueReader() (io.ReadWriteCloser, error) diff --git a/overlay/route.go b/overlay/route.go index 8ccc9943c..687cc11b8 100644 --- a/overlay/route.go +++ b/overlay/route.go @@ -61,7 +61,7 @@ func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table return routeTree, nil } -func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.routes") @@ -117,12 +117,20 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.routes failed to parse: %v", i+1, err) } - if !network.Contains(r.Cidr.Addr()) || r.Cidr.Bits() < network.Bits() { + found := false + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) && r.Cidr.Bits() >= network.Bits() { + found = true + break + } + } + + if !found { return nil, fmt.Errorf( - "entry %v.route in tun.routes is not contained within the network attached to the certificate; route: %v, network: %v", + "entry %v.route in tun.routes is not contained within the configured vpn networks; route: %v, networks: %v", i+1, r.Cidr.String(), - network.String(), + networks, ) } @@ -132,7 +140,7 @@ func parseRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return routes, nil } -func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { +func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) { var err error r := c.Get("tun.unsafe_routes") @@ -229,13 +237,15 @@ func parseUnsafeRoutes(c *config.C, network netip.Prefix) ([]Route, error) { return nil, fmt.Errorf("entry %v.route in tun.unsafe_routes failed to parse: %v", i+1, err) } - if network.Contains(r.Cidr.Addr()) { - return nil, fmt.Errorf( - "entry %v.route in tun.unsafe_routes is contained within the network attached to the certificate; route: %v, network: %v", - i+1, - r.Cidr.String(), - network.String(), - ) + for _, network := range networks { + if network.Contains(r.Cidr.Addr()) { + return nil, fmt.Errorf( + "entry %v.route in tun.unsafe_routes is contained within the configured vpn networks; route: %v, network: %v", + i+1, + r.Cidr.String(), + network.String(), + ) + } } routes[i] = r diff --git a/overlay/route_test.go b/overlay/route_test.go index d7913894b..c60e4c24b 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -17,76 +17,82 @@ func Test_parseRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseRoutes(c, n) + routes, err := parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 1.0.0.0/8, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the network attached to the certificate; route: 10.0.1.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + + // Not in multiple ranges + c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} + routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) + assert.Nil(t, routes) + assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ map[interface{}]interface{}{"mtu": "9000", "route": "10.0.0.0/29"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} - routes, err = parseRoutes(c, n) + routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 2) @@ -116,31 +122,31 @@ func Test_parseUnsafeRoutes(t *testing.T) { assert.NoError(t, err) // test no routes config - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 0) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") @@ -149,68 +155,68 @@ func Test_parseUnsafeRoutes(t *testing.T) { 127, false, nil, 1.0, []string{"1", "2"}, } { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the network attached to the certificate; route: 10.0.0.0/24, network: 10.0.0.0/24") + assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Nil(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) assert.Equal(t, 0, routes[0].MTU) // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") @@ -221,7 +227,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32", "install": 1}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} - routes, err = parseUnsafeRoutes(c, n) + routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, err) assert.Len(t, routes, 4) @@ -260,7 +266,7 @@ func Test_makeRouteTree(t *testing.T) { map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} - routes, err := parseUnsafeRoutes(c, n) + routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) assert.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) diff --git a/overlay/tun.go b/overlay/tun.go index 12460da1f..4a6377d2a 100644 --- a/overlay/tun.go +++ b/overlay/tun.go @@ -11,36 +11,36 @@ import ( const DefaultMTU = 1300 // TODO: We may be able to remove routines -type DeviceFactory func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) +type DeviceFactory func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) -func NewDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { +func NewDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { switch { case c.GetBool("tun.disabled", false): - tun := newDisabledTun(tunCidr, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) + tun := newDisabledTun(vpnNetworks, c.GetInt("tun.tx_queue", 500), c.GetBool("stats.message_metrics", false), l) return tun, nil default: - return newTun(c, l, tunCidr, routines > 1) + return newTun(c, l, vpnNetworks, routines > 1) } } func NewFdDeviceFromConfig(fd *int) DeviceFactory { - return func(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return newTunFromFd(c, l, *fd, tunCidr) + return func(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return newTunFromFd(c, l, *fd, vpnNetworks) } } -func getAllRoutesFromConfig(c *config.C, cidr netip.Prefix, initial bool) (bool, []Route, error) { +func getAllRoutesFromConfig(c *config.C, vpnNetworks []netip.Prefix, initial bool) (bool, []Route, error) { if !initial && !c.HasChanged("tun.routes") && !c.HasChanged("tun.unsafe_routes") { return false, nil, nil } - routes, err := parseRoutes(c, cidr) + routes, err := parseRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.routes", nil, err) } - unsafeRoutes, err := parseUnsafeRoutes(c, cidr) + unsafeRoutes, err := parseUnsafeRoutes(c, vpnNetworks) if err != nil { return true, nil, util.NewContextualError("Could not parse tun.unsafe_routes", nil, err) } diff --git a/overlay/tun_android.go b/overlay/tun_android.go index 98ad9b408..72a656500 100644 --- a/overlay/tun_android.go +++ b/overlay/tun_android.go @@ -18,14 +18,14 @@ import ( type tun struct { io.ReadWriteCloser - fd int - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + fd int + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { // XXX Android returns an fd in non-blocking mode which is necessary for shutdown to work properly. // Be sure not to call file.Fd() as it will set the fd to blocking mode. file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") @@ -33,7 +33,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix t := &tun{ ReadWriteCloser: file, fd: deviceFd, - cidr: cidr, + vpnNetworks: vpnNetworks, l: l, } @@ -52,7 +52,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in Android") } @@ -66,7 +66,7 @@ func (t tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -86,8 +86,8 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_darwin.go b/overlay/tun_darwin.go index 0b573e6b3..1cff2144c 100644 --- a/overlay/tun_darwin.go +++ b/overlay/tun_darwin.go @@ -24,56 +24,62 @@ import ( type tun struct { io.ReadWriteCloser - Device string - cidr netip.Prefix - DefaultMTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - linkAddr *netroute.LinkAddr - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + DefaultMTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + linkAddr *netroute.LinkAddr + l *logrus.Logger // cache out buffer since we need to prepend 4 bytes for tun metadata out []byte } -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - type ifReq struct { - Name [16]byte + Name [unix.IFNAMSIZ]byte Flags uint16 pad [8]byte } -var sockaddrCtlSize uintptr = 32 - const ( - _SYSPROTO_CONTROL = 2 //define SYSPROTO_CONTROL 2 /* kernel control protocol */ - _AF_SYS_CONTROL = 2 //#define AF_SYS_CONTROL 2 /* corresponding sub address type */ - _PF_SYSTEM = unix.AF_SYSTEM //#define PF_SYSTEM AF_SYSTEM - _CTLIOCGINFO = 3227799043 //#define CTLIOCGINFO _IOWR('N', 3, struct ctl_info) - utunControlName = "com.apple.net.utun_control" + _SIOCAIFADDR_IN6 = 2155899162 + _UTUN_OPT_IFNAME = 2 + _IN6_IFF_NODAD = 0x0020 + _IN6_IFF_SECURED = 0x0400 + utunControlName = "com.apple.net.utun_control" ) -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 pad [8]byte } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +type addrLifetime struct { + Expire float64 + Preferred float64 + Vltime uint32 + Pltime uint32 +} + +type ifreqAlias4 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet4 + DstAddr unix.RawSockaddrInet4 + MaskAddr unix.RawSockaddrInet4 +} + +type ifreqAlias6 struct { + Name [unix.IFNAMSIZ]byte + Addr unix.RawSockaddrInet6 + DstAddr unix.RawSockaddrInet6 + PrefixMask unix.RawSockaddrInet6 + Flags uint32 + Lifetime addrLifetime +} + +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { name := c.GetString("tun.dev", "") ifIndex := -1 if name != "" && name != "utun" { @@ -86,66 +92,41 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } } - fd, err := unix.Socket(_PF_SYSTEM, unix.SOCK_DGRAM, _SYSPROTO_CONTROL) + fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, unix.AF_SYS_CONTROL) if err != nil { return nil, fmt.Errorf("system socket: %v", err) } - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} + var ctlInfo = &unix.CtlInfo{} + copy(ctlInfo.Name[:], utunControlName) - copy(ctlInfo.ctlName[:], utunControlName) - - err = ioctl(uintptr(fd), uintptr(_CTLIOCGINFO), uintptr(unsafe.Pointer(ctlInfo))) + err = unix.IoctlCtlInfo(fd, ctlInfo) if err != nil { return nil, fmt.Errorf("CTLIOCGINFO: %v", err) } - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: _AF_SYS_CONTROL, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, - } - - _, _, errno := unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(unsafe.Pointer(&sc)), - sockaddrCtlSize, - ) - if errno != 0 { - return nil, fmt.Errorf("SYS_CONNECT: %v", errno) + err = unix.Connect(fd, &unix.SockaddrCtl{ + ID: ctlInfo.Id, + Unit: uint32(ifIndex) + 1, + }) + if err != nil { + return nil, fmt.Errorf("SYS_CONNECT: %v", err) } - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(len(ifName.name)) - _, _, errno = syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), - 2, // SYSPROTO_CONTROL - 2, // UTUN_OPT_IFNAME - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) - if errno != 0 { - return nil, fmt.Errorf("SYS_GETSOCKOPT: %v", errno) + name, err = unix.GetsockoptString(fd, unix.AF_SYS_CONTROL, _UTUN_OPT_IFNAME) + if err != nil { + return nil, fmt.Errorf("failed to retrieve tun name: %w", err) } - name = string(ifName.name[:ifNameSize-1]) - err = syscall.SetNonblock(fd, true) + err = unix.SetNonblock(fd, true) if err != nil { return nil, fmt.Errorf("SetNonblock: %v", err) } - file := os.NewFile(uintptr(fd), "") - t := &tun{ - ReadWriteCloser: file, + ReadWriteCloser: os.NewFile(uintptr(fd), ""), Device: name, - cidr: cidr, + vpnNetworks: vpnNetworks, DefaultMTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -172,7 +153,7 @@ func (t *tun) deviceBytes() (o [16]byte) { return } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in Darwin") } @@ -186,16 +167,6 @@ func (t *tun) Close() error { func (t *tun) Activate() error { devName := t.deviceBytes() - var addr, mask [4]byte - - if !t.cidr.Addr().Is4() { - //TODO: IPV6-WORK - panic("need ipv6") - } - - addr = t.cidr.Addr().As4() - copy(mask[:], prefixToMask(t.cidr)) - s, err := unix.Socket( unix.AF_INET, unix.SOCK_DGRAM, @@ -208,66 +179,18 @@ func (t *tun) Activate() error { fd := uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(fd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(fd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - - // Set the device name - ifrf := ifReq{Name: devName} - if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to set tun device name: %s", err) - } - // Set the MTU on the device ifm := ifreqMTU{Name: devName, MTU: int32(t.DefaultMTU)} if err = ioctl(fd, unix.SIOCSIFMTU, uintptr(unsafe.Pointer(&ifm))); err != nil { return fmt.Errorf("failed to set tun mtu: %v", err) } - /* - // Set the transmit queue length - ifrq := ifreqQLEN{Name: devName, Value: int32(t.TXQueueLen)} - if err = ioctl(fd, unix.SIOCSIFTXQLEN, uintptr(unsafe.Pointer(&ifrq))); err != nil { - // If we can't set the queue length nebula will still work but it may lead to packet loss - l.WithError(err).Error("Failed to set tun tx queue length") - } - */ - - // Bring up the interface - ifrf.Flags = ifrf.Flags | unix.IFF_UP - if err = ioctl(fd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { - return fmt.Errorf("failed to bring the tun device up: %s", err) - } - - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + // Get the device flags + ifrf := ifReq{Name: devName} + if err = ioctl(fd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { + return fmt.Errorf("failed to get tun flags: %s", err) } - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} linkAddr, err := getLinkAddr(t.Device) if err != nil { return err @@ -277,14 +200,18 @@ func (t *tun) Activate() error { } t.linkAddr = linkAddr - copy(routeAddr.IP[:], addr[:]) - copy(maskAddr.IP[:], mask[:]) - err = addRoute(routeSock, routeAddr, maskAddr, linkAddr) - if err != nil { - if errors.Is(err, unix.EEXIST) { - err = fmt.Errorf("unable to add tun route, identical route already exists: %s", t.cidr) + for _, network := range t.vpnNetworks { + if network.Addr().Is4() { + err = t.activate4(network) + if err != nil { + return err + } + } else { + err = t.activate6(network) + if err != nil { + return err + } } - return err } // Run the interface @@ -297,8 +224,89 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) activate4(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias4{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + DstAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: network.Addr().As4(), + }, + MaskAddr: unix.RawSockaddrInet4{ + Len: unix.SizeofSockaddrInet4, + Family: unix.AF_INET, + Addr: prefixToMask(network).As4(), + }, + } + + if err := ioctl(uintptr(s), unix.SIOCAIFADDR, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun v4 address: %s", err) + } + + err = addRoute(network, t.linkAddr) + if err != nil { + return err + } + + return nil +} + +func (t *tun) activate6(network netip.Prefix) error { + s, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + unix.IPPROTO_IP, + ) + if err != nil { + return err + } + defer unix.Close(s) + + ifr := ifreqAlias6{ + Name: t.deviceBytes(), + Addr: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: network.Addr().As16(), + }, + PrefixMask: unix.RawSockaddrInet6{ + Len: unix.SizeofSockaddrInet6, + Family: unix.AF_INET6, + Addr: prefixToMask(network).As16(), + }, + Lifetime: addrLifetime{ + // never expires + Vltime: 0xffffffff, + Pltime: 0xffffffff, + }, + //TODO: should we disable DAD (duplicate address detection) and mark this as a secured address? + Flags: _IN6_IFF_NODAD, + } + + if err := ioctl(uintptr(s), _SIOCAIFADDR_IN6, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set tun address: %s", err) + } + + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -371,38 +379,15 @@ func getLinkAddr(name string) (*netroute.LinkAddr, error) { } func (t *tun) addRoutes(logErrors bool) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} routes := *t.Routes.Load() + for _, r := range routes { if !r.Via.IsValid() || !r.Install { // We don't allow route MTUs so only install routes with a via continue } - if !r.Cidr.Addr().Is4() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - //TODO: we could avoid the copy - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := addRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := addRoute(r.Cidr, t.linkAddr) if err != nil { if errors.Is(err, unix.EEXIST) { t.l.WithField("route", r.Cidr). @@ -424,36 +409,12 @@ func (t *tun) addRoutes(logErrors bool) error { } func (t *tun) removeRoutes(routes []Route) error { - routeSock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) - if err != nil { - return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) - } - - defer func() { - unix.Shutdown(routeSock, unix.SHUT_RDWR) - err := unix.Close(routeSock) - if err != nil { - t.l.WithError(err).Error("failed to close AF_ROUTE socket") - } - }() - - routeAddr := &netroute.Inet4Addr{} - maskAddr := &netroute.Inet4Addr{} - for _, r := range routes { if !r.Install { continue } - if r.Cidr.Addr().Is6() { - //TODO: implement ipv6 - panic("Cant handle ipv6 routes yet") - } - - routeAddr.IP = r.Cidr.Addr().As4() - copy(maskAddr.IP[:], prefixToMask(r.Cidr)) - - err := delRoute(routeSock, routeAddr, maskAddr, t.linkAddr) + err := delRoute(r.Cidr, t.linkAddr) if err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") } else { @@ -463,23 +424,39 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func addRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := &netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_ADD, Flags: unix.RTF_UP, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } + _, err = unix.Write(sock, data[:]) if err != nil { return fmt.Errorf("failed to write route.RouteMessage to socket: %w", err) @@ -488,19 +465,34 @@ func addRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) return nil } -func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) error { - r := netroute.RouteMessage{ +func delRoute(prefix netip.Prefix, gateway netroute.Addr) error { + sock, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return fmt.Errorf("unable to create AF_ROUTE socket: %v", err) + } + defer unix.Close(sock) + + route := netroute.RouteMessage{ Version: unix.RTM_VERSION, Type: unix.RTM_DELETE, Seq: 1, - Addrs: []netroute.Addr{ - unix.RTAX_DST: addr, - unix.RTAX_GATEWAY: link, - unix.RTAX_NETMASK: mask, - }, } - data, err := r.Marshal() + if prefix.Addr().Is4() { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet4Addr{IP: prefix.Masked().Addr().As4()}, + unix.RTAX_NETMASK: &netroute.Inet4Addr{IP: prefixToMask(prefix).As4()}, + unix.RTAX_GATEWAY: gateway, + } + } else { + route.Addrs = []netroute.Addr{ + unix.RTAX_DST: &netroute.Inet6Addr{IP: prefix.Masked().Addr().As16()}, + unix.RTAX_NETMASK: &netroute.Inet6Addr{IP: prefixToMask(prefix).As16()}, + unix.RTAX_GATEWAY: gateway, + } + } + + data, err := route.Marshal() if err != nil { return fmt.Errorf("failed to create route.RouteMessage: %w", err) } @@ -513,7 +505,6 @@ func delRoute(sock int, addr, mask *netroute.Inet4Addr, link *netroute.LinkAddr) } func (t *tun) Read(to []byte) (int, error) { - buf := make([]byte, len(to)+4) n, err := t.ReadWriteCloser.Read(buf) @@ -551,8 +542,8 @@ func (t *tun) Write(from []byte) (int, error) { return n - 4, err } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -563,10 +554,12 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) { return nil, fmt.Errorf("TODO: multiqueue not implemented for darwin") } -func prefixToMask(prefix netip.Prefix) []byte { +func prefixToMask(prefix netip.Prefix) netip.Addr { pLen := 128 if prefix.Addr().Is4() { pLen = 32 } - return net.CIDRMask(prefix.Bits(), pLen) + + addr, _ := netip.AddrFromSlice(net.CIDRMask(prefix.Bits(), pLen)) + return addr } diff --git a/overlay/tun_disabled.go b/overlay/tun_disabled.go index 130f8f99f..cfbf17d97 100644 --- a/overlay/tun_disabled.go +++ b/overlay/tun_disabled.go @@ -12,8 +12,8 @@ import ( ) type disabledTun struct { - read chan []byte - cidr netip.Prefix + read chan []byte + vpnNetworks []netip.Prefix // Track these metrics since we don't have the tun device to do it for us tx metrics.Counter @@ -21,11 +21,11 @@ type disabledTun struct { l *logrus.Logger } -func newDisabledTun(cidr netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { +func newDisabledTun(vpnNetworks []netip.Prefix, queueLen int, metricsEnabled bool, l *logrus.Logger) *disabledTun { tun := &disabledTun{ - cidr: cidr, - read: make(chan []byte, queueLen), - l: l, + vpnNetworks: vpnNetworks, + read: make(chan []byte, queueLen), + l: l, } if metricsEnabled { @@ -47,8 +47,8 @@ func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr { return netip.Addr{} } -func (t *disabledTun) Cidr() netip.Prefix { - return t.cidr +func (t *disabledTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (*disabledTun) Name() string { diff --git a/overlay/tun_freebsd.go b/overlay/tun_freebsd.go index bdfeb5802..69690e948 100644 --- a/overlay/tun_freebsd.go +++ b/overlay/tun_freebsd.go @@ -46,12 +46,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -78,11 +78,11 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in FreeBSD") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open existing tun device var file *os.File var err error @@ -150,7 +150,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -170,16 +170,16 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), "-interface", t.Device) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), "-interface", t.Device) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -195,8 +195,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -237,8 +247,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_ios.go b/overlay/tun_ios.go index 20981f08c..e99d44718 100644 --- a/overlay/tun_ios.go +++ b/overlay/tun_ios.go @@ -21,20 +21,20 @@ import ( type tun struct { io.ReadWriteCloser - cidr netip.Prefix - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + vpnNetworks []netip.Prefix + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger } -func newTun(_ *config.C, _ *logrus.Logger, _ netip.Prefix, _ bool) (*tun, error) { +func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, error) { return nil, fmt.Errorf("newTun not supported in iOS") } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/tun") t := &tun{ - cidr: cidr, + vpnNetworks: vpnNetworks, ReadWriteCloser: &tunReadCloser{f: file}, l: l, } @@ -59,7 +59,7 @@ func (t *tun) Activate() error { } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -142,8 +142,8 @@ func (tr *tunReadCloser) Close() error { return tr.f.Close() } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_linux.go b/overlay/tun_linux.go index 0e7e20d41..d8501728f 100644 --- a/overlay/tun_linux.go +++ b/overlay/tun_linux.go @@ -11,6 +11,7 @@ import ( "os" "strings" "sync/atomic" + "time" "unsafe" "github.com/gaissmai/bart" @@ -25,7 +26,7 @@ type tun struct { io.ReadWriteCloser fd int Device string - cidr netip.Prefix + vpnNetworks []netip.Prefix MaxMTU int DefaultMTU int TXQueueLen int @@ -40,18 +41,16 @@ type tun struct { l *logrus.Logger } +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks +} + type ifReq struct { Name [16]byte Flags uint16 pad [8]byte } -type ifreqAddr struct { - Name [16]byte - Addr unix.RawSockaddrInet4 - pad [8]byte -} - type ifreqMTU struct { Name [16]byte MTU int32 @@ -64,10 +63,10 @@ type ifreqQLEN struct { pad [8]byte } -func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix) (*tun, error) { +func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, vpnNetworks []netip.Prefix) (*tun, error) { file := os.NewFile(uintptr(deviceFd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -77,7 +76,7 @@ func newTunFromFd(c *config.C, l *logrus.Logger, deviceFd int, cidr netip.Prefix return t, nil } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, multiqueue bool) (*tun, error) { fd, err := unix.Open("/dev/net/tun", os.O_RDWR, 0) if err != nil { // If /dev/net/tun doesn't exist, try to create it (will happen in docker) @@ -112,7 +111,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( name := strings.Trim(string(req.Name[:]), "\x00") file := os.NewFile(uintptr(fd), "/dev/net/tun") - t, err := newTunGeneric(c, l, file, cidr) + t, err := newTunGeneric(c, l, file, vpnNetworks) if err != nil { return nil, err } @@ -122,11 +121,11 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) ( return t, nil } -func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Prefix) (*tun, error) { +func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, vpnNetworks []netip.Prefix) (*tun, error) { t := &tun{ ReadWriteCloser: file, fd: int(file.Fd()), - cidr: cidr, + vpnNetworks: vpnNetworks, TXQueueLen: c.GetInt("tun.tx_queue", 500), useSystemRoutes: c.GetBool("tun.use_system_route_table", false), l: l, @@ -148,7 +147,7 @@ func newTunGeneric(c *config.C, l *logrus.Logger, file *os.File, cidr netip.Pref } func (t *tun) reload(c *config.C, initial bool) error { - routeChange, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + routeChange, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -190,11 +189,13 @@ func (t *tun) reload(c *config.C, initial bool) error { } if oldDefaultMTU != newDefaultMTU { - err := t.setDefaultRoute() - if err != nil { - t.l.Warn(err) - } else { - t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + for i := range t.vpnNetworks { + err := t.setDefaultRoute(t.vpnNetworks[i]) + if err != nil { + t.l.Warn(err) + } else { + t.l.Infof("Set default MTU to %v was %v", t.DefaultMTU, oldDefaultMTU) + } } } @@ -237,10 +238,10 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { func (t *tun) Write(b []byte) (int, error) { var nn int - max := len(b) + maximum := len(b) for { - n, err := unix.Write(t.fd, b[nn:max]) + n, err := unix.Write(t.fd, b[nn:maximum]) if n > 0 { nn += n } @@ -265,6 +266,58 @@ func (t *tun) deviceBytes() (o [16]byte) { return } +func hasNetlinkAddr(al []*netlink.Addr, x netlink.Addr) bool { + for i := range al { + if al[i].Equal(x) { + return true + } + } + return false +} + +// addIPs uses netlink to add all addresses that don't exist, then it removes ones that should not be there +func (t *tun) addIPs(link netlink.Link) error { + newAddrs := make([]*netlink.Addr, len(t.vpnNetworks)) + for i := range t.vpnNetworks { + newAddrs[i] = &netlink.Addr{ + IPNet: &net.IPNet{ + IP: t.vpnNetworks[i].Addr().AsSlice(), + Mask: net.CIDRMask(t.vpnNetworks[i].Bits(), t.vpnNetworks[i].Addr().BitLen()), + }, + Label: t.vpnNetworks[i].Addr().Zone(), + } + } + + //add all new addresses + for i := range newAddrs { + //todo do we want to stack errors and try as many ops as possible? + //AddrReplace still adds new IPs, but if their properties change it will change them as well + if err := netlink.AddrReplace(link, newAddrs[i]); err != nil { + return err + } + } + + //iterate over remainder, remove whoever shouldn't be there + al, err := netlink.AddrList(link, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed to get tun address list: %s", err) + } + + for i := range al { + if hasNetlinkAddr(newAddrs, al[i]) { + continue + } + err = netlink.AddrDel(link, &al[i]) + if err != nil { + t.l.WithError(err).Error("failed to remove address from tun address list") + } else { + t.l.WithField("removed", al[i].String()).Info("removed address not listed in cert(s)") + } + } + + return nil +} + func (t *tun) Activate() error { devName := t.deviceBytes() @@ -272,15 +325,8 @@ func (t *tun) Activate() error { t.watchRoutes() } - var addr, mask [4]byte - - //TODO: IPV6-WORK - addr = t.cidr.Addr().As4() - tmask := net.CIDRMask(t.cidr.Bits(), 32) - copy(mask[:], tmask) - s, err := unix.Socket( - unix.AF_INET, + unix.AF_INET, //because everything we use t.ioctlFd for is address family independent, this is fine unix.SOCK_DGRAM, unix.IPPROTO_IP, ) @@ -289,31 +335,19 @@ func (t *tun) Activate() error { } t.ioctlFd = uintptr(s) - ifra := ifreqAddr{ - Name: devName, - Addr: unix.RawSockaddrInet4{ - Family: unix.AF_INET, - Addr: addr, - }, - } - - // Set the device ip address - if err = ioctl(t.ioctlFd, unix.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun address: %s", err) - } - - // Set the device network - ifra.Addr.Addr = mask - if err = ioctl(t.ioctlFd, unix.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { - return fmt.Errorf("failed to set tun netmask: %s", err) - } - // Set the device name ifrf := ifReq{Name: devName} if err = ioctl(t.ioctlFd, unix.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to set tun device name: %s", err) } + link, err := netlink.LinkByName(t.Device) + if err != nil { + return fmt.Errorf("failed to get tun device link: %s", err) + } + + t.deviceIndex = link.Attrs().Index + // Setup our default MTU t.setMTU() @@ -324,20 +358,21 @@ func (t *tun) Activate() error { t.l.WithError(err).Error("Failed to set tun tx queue length") } + if err = t.addIPs(link); err != nil { + return err + } + // Bring up the interface ifrf.Flags = ifrf.Flags | unix.IFF_UP if err = ioctl(t.ioctlFd, unix.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifrf))); err != nil { return fmt.Errorf("failed to bring the tun device up: %s", err) } - link, err := netlink.LinkByName(t.Device) - if err != nil { - return fmt.Errorf("failed to get tun device link: %s", err) - } - t.deviceIndex = link.Attrs().Index - - if err = t.setDefaultRoute(); err != nil { - return err + //set route MTU + for i := range t.vpnNetworks { + if err = t.setDefaultRoute(t.vpnNetworks[i]); err != nil { + return fmt.Errorf("failed to set default route MTU: %w", err) + } } // Set the routes @@ -363,12 +398,10 @@ func (t *tun) setMTU() { } } -func (t *tun) setDefaultRoute() error { - // Default route - +func (t *tun) setDefaultRoute(cidr netip.Prefix) error { dr := &net.IPNet{ - IP: t.cidr.Masked().Addr().AsSlice(), - Mask: net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen()), + IP: cidr.Masked().Addr().AsSlice(), + Mask: net.CIDRMask(cidr.Bits(), cidr.Addr().BitLen()), } nr := netlink.Route{ @@ -377,14 +410,27 @@ func (t *tun) setDefaultRoute() error { MTU: t.DefaultMTU, AdvMSS: t.advMSS(Route{}), Scope: unix.RT_SCOPE_LINK, - Src: net.IP(t.cidr.Addr().AsSlice()), + Src: net.IP(cidr.Addr().AsSlice()), Protocol: unix.RTPROT_KERNEL, Table: unix.RT_TABLE_MAIN, Type: unix.RTN_UNICAST, } err := netlink.RouteReplace(&nr) if err != nil { - return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + t.l.WithError(err).WithField("cidr", cidr).Warn("Failed to set default route MTU, retrying") + //retry twice more -- on some systems there appears to be a race condition where if we set routes too soon, netlink says `invalid argument` + for i := 0; i < 2; i++ { + time.Sleep(100 * time.Millisecond) + err = netlink.RouteReplace(&nr) + if err == nil { + break + } else { + t.l.WithError(err).WithField("cidr", cidr).WithField("mtu", t.DefaultMTU).Warn("Failed to set default route MTU, retrying") + } + } + if err != nil { + return fmt.Errorf("failed to set mtu %v on the default route %v; %v", t.DefaultMTU, dr, err) + } } return nil @@ -463,10 +509,6 @@ func (t *tun) removeRoutes(routes []Route) { } } -func (t *tun) Cidr() netip.Prefix { - return t.cidr -} - func (t *tun) Name() string { return t.Device } @@ -515,7 +557,6 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { return } - //TODO: IPV6-WORK what if not ok? gwAddr, ok := netip.AddrFromSlice(r.Gw) if !ok { t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address") @@ -523,15 +564,16 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) { } gwAddr = gwAddr.Unmap() - if !t.cidr.Contains(gwAddr) { - // Gateway isn't in our overlay network, ignore - t.l.WithField("route", r).Debug("Ignoring route update, not in our network") - return + withinNetworks := false + for i := range t.vpnNetworks { + if t.vpnNetworks[i].Contains(gwAddr) { + withinNetworks = true + break + } } - - if x := r.Dst.IP.To4(); x == nil { - // Nebula only handles ipv4 on the overlay currently - t.l.WithField("route", r).Debug("Ignoring route update, destination is not ipv4") + if !withinNetworks { + // Gateway isn't in our overlay network, ignore + t.l.WithField("route", r).Debug("Ignoring route update, not in our networks") return } @@ -563,11 +605,11 @@ func (t *tun) Close() error { } if t.ReadWriteCloser != nil { - t.ReadWriteCloser.Close() + _ = t.ReadWriteCloser.Close() } if t.ioctlFd > 0 { - os.NewFile(t.ioctlFd, "ioctlFd").Close() + _ = os.NewFile(t.ioctlFd, "ioctlFd").Close() } return nil diff --git a/overlay/tun_netbsd.go b/overlay/tun_netbsd.go index 24ab24f78..54984910e 100644 --- a/overlay/tun_netbsd.go +++ b/overlay/tun_netbsd.go @@ -27,12 +27,12 @@ type ifreqDestroy struct { } type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser } @@ -58,13 +58,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in NetBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { // Try to open tun device var file *os.File var err error @@ -84,7 +84,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -104,17 +104,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err return t, nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-net", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-net", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -130,8 +130,18 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -172,8 +182,8 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr { return r } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { @@ -192,7 +202,7 @@ func (t *tun) addRoutes(logErrors bool) error { continue } - cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/route", "-n", "add", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -213,7 +223,8 @@ func (t *tun) removeRoutes(routes []Route) error { continue } - cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-net", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") diff --git a/overlay/tun_openbsd.go b/overlay/tun_openbsd.go index 6463ccbba..dcba05478 100644 --- a/overlay/tun_openbsd.go +++ b/overlay/tun_openbsd.go @@ -21,12 +21,12 @@ import ( ) type tun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger io.ReadWriteCloser @@ -42,13 +42,13 @@ func (t *tun) Close() error { return nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*tun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*tun, error) { return nil, fmt.Errorf("newTunFromFd not supported in OpenBSD") } var deviceNameRE = regexp.MustCompile(`^tun[0-9]+$`) -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, error) { +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*tun, error) { deviceName := c.GetString("tun.dev", "") if deviceName == "" { return nil, fmt.Errorf("a device name in the format of tunN must be specified") @@ -66,7 +66,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err t := &tun{ ReadWriteCloser: file, Device: deviceName, - cidr: cidr, + vpnNetworks: vpnNetworks, MTU: c.GetInt("tun.mtu", DefaultMTU), l: l, } @@ -87,7 +87,7 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*tun, err } func (t *tun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { return err } @@ -123,10 +123,10 @@ func (t *tun) reload(c *config.C, initial bool) error { return nil } -func (t *tun) Activate() error { +func (t *tun) addIp(cidr netip.Prefix) error { var err error // TODO use syscalls instead of exec.Command - cmd := exec.Command("/sbin/ifconfig", t.Device, t.cidr.String(), t.cidr.Addr().String()) + cmd := exec.Command("/sbin/ifconfig", t.Device, cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'ifconfig': %s", err) @@ -138,7 +138,7 @@ func (t *tun) Activate() error { return fmt.Errorf("failed to run 'ifconfig': %s", err) } - cmd = exec.Command("/sbin/route", "-n", "add", "-inet", t.cidr.String(), t.cidr.Addr().String()) + cmd = exec.Command("/sbin/route", "-n", "add", "-inet", cidr.String(), cidr.Addr().String()) t.l.Debug("command: ", cmd.String()) if err = cmd.Run(); err != nil { return fmt.Errorf("failed to run 'route add': %s", err) @@ -148,6 +148,16 @@ func (t *tun) Activate() error { return t.addRoutes(false) } +func (t *tun) Activate() error { + for i := range t.vpnNetworks { + err := t.addIp(t.vpnNetworks[i]) + if err != nil { + return err + } + } + return nil +} + func (t *tun) RouteFor(ip netip.Addr) netip.Addr { r, _ := t.routeTree.Load().Lookup(ip) return r @@ -160,8 +170,8 @@ func (t *tun) addRoutes(logErrors bool) error { // We don't allow route MTUs so only install routes with a via continue } - - cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "add", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { retErr := util.NewContextualError("failed to run 'route add' for unsafe_route", map[string]interface{}{"route": r}, err) @@ -181,8 +191,8 @@ func (t *tun) removeRoutes(routes []Route) error { if !r.Install { continue } - - cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.cidr.Addr().String()) + //todo is this right? + cmd := exec.Command("/sbin/route", "-n", "delete", "-inet", r.Cidr.String(), t.vpnNetworks[0].Addr().String()) t.l.Debug("command: ", cmd.String()) if err := cmd.Run(); err != nil { t.l.WithError(err).WithField("route", r).Error("Failed to remove route") @@ -193,8 +203,8 @@ func (t *tun) removeRoutes(routes []Route) error { return nil } -func (t *tun) Cidr() netip.Prefix { - return t.cidr +func (t *tun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *tun) Name() string { diff --git a/overlay/tun_tester.go b/overlay/tun_tester.go index ba15723a1..cc3942f41 100644 --- a/overlay/tun_tester.go +++ b/overlay/tun_tester.go @@ -16,19 +16,19 @@ import ( ) type TestTun struct { - Device string - cidr netip.Prefix - Routes []Route - routeTree *bart.Table[netip.Addr] - l *logrus.Logger + Device string + vpnNetworks []netip.Prefix + Routes []Route + routeTree *bart.Table[netip.Addr] + l *logrus.Logger closed atomic.Bool rxPackets chan []byte // Packets to receive into nebula TxPackets chan []byte // Packets transmitted outside by nebula } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, error) { - _, routes, err := getAllRoutesFromConfig(c, cidr, true) +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*TestTun, error) { + _, routes, err := getAllRoutesFromConfig(c, vpnNetworks, true) if err != nil { return nil, err } @@ -38,17 +38,17 @@ func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*TestTun, } return &TestTun{ - Device: c.GetString("tun.dev", ""), - cidr: cidr, - Routes: routes, - routeTree: routeTree, - l: l, - rxPackets: make(chan []byte, 10), - TxPackets: make(chan []byte, 10), + Device: c.GetString("tun.dev", ""), + vpnNetworks: vpnNetworks, + Routes: routes, + routeTree: routeTree, + l: l, + rxPackets: make(chan []byte, 10), + TxPackets: make(chan []byte, 10), }, nil } -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (*TestTun, error) { +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (*TestTun, error) { return nil, fmt.Errorf("newTunFromFd not supported") } @@ -95,8 +95,8 @@ func (t *TestTun) Activate() error { return nil } -func (t *TestTun) Cidr() netip.Prefix { - return t.cidr +func (t *TestTun) Networks() []netip.Prefix { + return t.vpnNetworks } func (t *TestTun) Name() string { diff --git a/overlay/tun_water_windows.go b/overlay/tun_water_windows.go deleted file mode 100644 index d78f564cf..000000000 --- a/overlay/tun_water_windows.go +++ /dev/null @@ -1,208 +0,0 @@ -package overlay - -import ( - "fmt" - "io" - "net" - "net/netip" - "os/exec" - "strconv" - "sync/atomic" - - "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" - "github.com/songgao/water" -) - -type waterTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - f *net.Interface - *water.Interface -} - -func newWaterTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*waterTun, error) { - // NOTE: You cannot set the deviceName under Windows, so you must check tun.Device after calling .Activate() - t := &waterTun{ - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err := t.reload(c, true) - if err != nil { - return nil, err - } - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *waterTun) Activate() error { - var err error - t.Interface, err = water.New(water.Config{ - DeviceType: water.TUN, - PlatformSpecificParams: water.PlatformSpecificParams{ - ComponentID: "tap0901", - Network: t.cidr.String(), - }, - }) - if err != nil { - return fmt.Errorf("activate failed: %v", err) - } - - t.Device = t.Interface.Name() - - // TODO use syscalls instead of exec.Command - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", t.Device), - "source=static", - fmt.Sprintf("addr=%s", t.cidr.Addr()), - fmt.Sprintf("mask=%s", net.CIDRMask(t.cidr.Bits(), t.cidr.Addr().BitLen())), - "gateway=none", - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set address: %s", err) - } - err = exec.Command( - `C:\Windows\System32\netsh.exe`, "interface", "ipv4", "set", "interface", - t.Device, - fmt.Sprintf("mtu=%d", t.MTU), - ).Run() - if err != nil { - return fmt.Errorf("failed to run 'netsh' to set MTU: %s", err) - } - - t.f, err = net.InterfaceByName(t.Device) - if err != nil { - return fmt.Errorf("failed to find interface named %s: %v", t.Device, err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *waterTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to set routes", err, t.l) - } else { - for _, r := range findRemovedRoutes(routes, *oldRoutes) { - t.l.WithField("route", r).Info("Removed route") - } - } - } - - return nil -} - -func (t *waterTun) addRoutes(logErrors bool) error { - // Path routes - routes := *t.Routes.Load() - for _, r := range routes { - if !r.Via.IsValid() || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "add", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - } - - return nil -} - -func (t *waterTun) removeRoutes(routes []Route) { - for _, r := range routes { - if !r.Install { - continue - } - - err := exec.Command( - "C:\\Windows\\System32\\route.exe", "delete", r.Cidr.String(), r.Via.String(), "IF", strconv.Itoa(t.f.Index), "METRIC", strconv.Itoa(r.Metric), - ).Run() - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } -} - -func (t *waterTun) RouteFor(ip netip.Addr) netip.Addr { - r, _ := t.routeTree.Load().Lookup(ip) - return r -} - -func (t *waterTun) Cidr() netip.Prefix { - return t.cidr -} - -func (t *waterTun) Name() string { - return t.Device -} - -func (t *waterTun) Close() error { - if t.Interface == nil { - return nil - } - - return t.Interface.Close() -} - -func (t *waterTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} diff --git a/overlay/tun_windows.go b/overlay/tun_windows.go index 3d883093c..289999d8a 100644 --- a/overlay/tun_windows.go +++ b/overlay/tun_windows.go @@ -4,41 +4,268 @@ package overlay import ( + "crypto" "fmt" + "io" "net/netip" "os" "path/filepath" "runtime" + "sync/atomic" "syscall" + "unsafe" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/util" + "github.com/slackhq/nebula/wintun" + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" ) -func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ netip.Prefix) (Device, error) { +const tunGUIDLabel = "Fixed Nebula Windows GUID v1" + +type winTun struct { + Device string + vpnNetworks []netip.Prefix + MTU int + Routes atomic.Pointer[[]Route] + routeTree atomic.Pointer[bart.Table[netip.Addr]] + l *logrus.Logger + + tun *wintun.NativeTun +} + +func newTunFromFd(_ *config.C, _ *logrus.Logger, _ int, _ []netip.Prefix) (Device, error) { return nil, fmt.Errorf("newTunFromFd not supported in Windows") } -func newTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, multiqueue bool) (Device, error) { - useWintun := true - if err := checkWinTunExists(); err != nil { - l.WithError(err).Warn("Check Wintun driver failed, fallback to wintap driver") - useWintun = false +func newTun(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, _ bool) (*winTun, error) { + err := checkWinTunExists() + if err != nil { + return nil, fmt.Errorf("can not load the wintun driver: %w", err) + } + + deviceName := c.GetString("tun.dev", "") + guid, err := generateGUIDByDeviceName(deviceName) + if err != nil { + return nil, fmt.Errorf("generate GUID failed: %w", err) + } + + t := &winTun{ + Device: deviceName, + vpnNetworks: vpnNetworks, + MTU: c.GetInt("tun.mtu", DefaultMTU), + l: l, + } + + err = t.reload(c, true) + if err != nil { + return nil, err } - if useWintun { - device, err := newWinTun(c, l, cidr, multiqueue) + var tunDevice wintun.Device + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) + if err != nil { + // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. + // Trying a second time resolves the issue. + l.WithError(err).Debug("Failed to create wintun device, retrying") + tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) if err != nil { - return nil, fmt.Errorf("create Wintun interface failed, %w", err) + return nil, fmt.Errorf("create TUN device failed: %w", err) } - return device, nil } + t.tun = tunDevice.(*wintun.NativeTun) + + c.RegisterReloadCallback(func(c *config.C) { + err := t.reload(c, false) + if err != nil { + util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) + } + }) + + return t, nil +} - device, err := newWaterTun(c, l, cidr, multiqueue) +func (t *winTun) reload(c *config.C, initial bool) error { + change, routes, err := getAllRoutesFromConfig(c, t.vpnNetworks, initial) if err != nil { - return nil, fmt.Errorf("create wintap driver failed, %w", err) + return err + } + + if !initial && !change { + return nil } - return device, nil + + routeTree, err := makeRouteTree(t.l, routes, false) + if err != nil { + return err + } + + // Teach nebula how to handle the routes before establishing them in the system table + oldRoutes := t.Routes.Swap(&routes) + t.routeTree.Store(routeTree) + + if !initial { + // Remove first, if the system removes a wanted route hopefully it will be re-added next + err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) + if err != nil { + util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) + } + + // Ensure any routes we actually want are installed + err = t.addRoutes(true) + if err != nil { + // Catch any stray logs + util.LogWithContextIfNeeded("Failed to add routes", err, t.l) + } + } + + return nil +} + +func (t *winTun) Activate() error { + luid := winipcfg.LUID(t.tun.LUID()) + + err := luid.SetIPAddresses(t.vpnNetworks) + if err != nil { + return fmt.Errorf("failed to set address: %w", err) + } + + err = t.addRoutes(false) + if err != nil { + return err + } + + return nil +} + +func (t *winTun) addRoutes(logErrors bool) error { + luid := winipcfg.LUID(t.tun.LUID()) + routes := *t.Routes.Load() + foundDefault4 := false + + for _, r := range routes { + if !r.Via.IsValid() || !r.Install { + // We don't allow route MTUs so only install routes with a via + continue + } + + // Add our unsafe route + err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) + if err != nil { + retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) + if logErrors { + retErr.Log(t.l) + continue + } else { + return retErr + } + } else { + t.l.WithField("route", r).Info("Added route") + } + + if !foundDefault4 { + if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { + foundDefault4 = true + } + } + } + + ipif, err := luid.IPInterface(windows.AF_INET) + if err != nil { + return fmt.Errorf("failed to get ip interface: %w", err) + } + + ipif.NLMTU = uint32(t.MTU) + if foundDefault4 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + + if err := ipif.Set(); err != nil { + return fmt.Errorf("failed to set ip interface: %w", err) + } + return nil +} + +func (t *winTun) removeRoutes(routes []Route) error { + luid := winipcfg.LUID(t.tun.LUID()) + + for _, r := range routes { + if !r.Install { + continue + } + + err := luid.DeleteRoute(r.Cidr, r.Via) + if err != nil { + t.l.WithError(err).WithField("route", r).Error("Failed to remove route") + } else { + t.l.WithField("route", r).Info("Removed route") + } + } + return nil +} + +func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { + r, _ := t.routeTree.Load().Lookup(ip) + return r +} + +func (t *winTun) Networks() []netip.Prefix { + return t.vpnNetworks +} + +func (t *winTun) Name() string { + return t.Device +} + +func (t *winTun) Read(b []byte) (int, error) { + return t.tun.Read(b, 0) +} + +func (t *winTun) Write(b []byte) (int, error) { + return t.tun.Write(b, 0) +} + +func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { + return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") +} + +func (t *winTun) Close() error { + // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, + // so to be certain, just remove everything before destroying. + luid := winipcfg.LUID(t.tun.LUID()) + _ = luid.FlushRoutes(windows.AF_INET) + _ = luid.FlushIPAddresses(windows.AF_INET) + + _ = luid.FlushRoutes(windows.AF_INET6) + _ = luid.FlushIPAddresses(windows.AF_INET6) + + _ = luid.FlushDNS(windows.AF_INET) + _ = luid.FlushDNS(windows.AF_INET6) + + return t.tun.Close() +} + +func generateGUIDByDeviceName(name string) (*windows.GUID, error) { + // GUID is 128 bit + hash := crypto.MD5.New() + + _, err := hash.Write([]byte(tunGUIDLabel)) + if err != nil { + return nil, err + } + + _, err = hash.Write([]byte(name)) + if err != nil { + return nil, err + } + + sum := hash.Sum(nil) + + return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil } func checkWinTunExists() error { diff --git a/overlay/tun_wintun_windows.go b/overlay/tun_wintun_windows.go deleted file mode 100644 index d0103879a..000000000 --- a/overlay/tun_wintun_windows.go +++ /dev/null @@ -1,252 +0,0 @@ -package overlay - -import ( - "crypto" - "fmt" - "io" - "net/netip" - "sync/atomic" - "unsafe" - - "github.com/gaissmai/bart" - "github.com/sirupsen/logrus" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/util" - "github.com/slackhq/nebula/wintun" - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" -) - -const tunGUIDLabel = "Fixed Nebula Windows GUID v1" - -type winTun struct { - Device string - cidr netip.Prefix - MTU int - Routes atomic.Pointer[[]Route] - routeTree atomic.Pointer[bart.Table[netip.Addr]] - l *logrus.Logger - - tun *wintun.NativeTun -} - -func generateGUIDByDeviceName(name string) (*windows.GUID, error) { - // GUID is 128 bit - hash := crypto.MD5.New() - - _, err := hash.Write([]byte(tunGUIDLabel)) - if err != nil { - return nil, err - } - - _, err = hash.Write([]byte(name)) - if err != nil { - return nil, err - } - - sum := hash.Sum(nil) - - return (*windows.GUID)(unsafe.Pointer(&sum[0])), nil -} - -func newWinTun(c *config.C, l *logrus.Logger, cidr netip.Prefix, _ bool) (*winTun, error) { - deviceName := c.GetString("tun.dev", "") - guid, err := generateGUIDByDeviceName(deviceName) - if err != nil { - return nil, fmt.Errorf("generate GUID failed: %w", err) - } - - t := &winTun{ - Device: deviceName, - cidr: cidr, - MTU: c.GetInt("tun.mtu", DefaultMTU), - l: l, - } - - err = t.reload(c, true) - if err != nil { - return nil, err - } - - var tunDevice wintun.Device - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - // Windows 10 has an issue with unclean shutdowns not fully cleaning up the wintun device. - // Trying a second time resolves the issue. - l.WithError(err).Debug("Failed to create wintun device, retrying") - tunDevice, err = wintun.CreateTUNWithRequestedGUID(deviceName, guid, t.MTU) - if err != nil { - return nil, fmt.Errorf("create TUN device failed: %w", err) - } - } - t.tun = tunDevice.(*wintun.NativeTun) - - c.RegisterReloadCallback(func(c *config.C) { - err := t.reload(c, false) - if err != nil { - util.LogWithContextIfNeeded("failed to reload tun device", err, t.l) - } - }) - - return t, nil -} - -func (t *winTun) reload(c *config.C, initial bool) error { - change, routes, err := getAllRoutesFromConfig(c, t.cidr, initial) - if err != nil { - return err - } - - if !initial && !change { - return nil - } - - routeTree, err := makeRouteTree(t.l, routes, false) - if err != nil { - return err - } - - // Teach nebula how to handle the routes before establishing them in the system table - oldRoutes := t.Routes.Swap(&routes) - t.routeTree.Store(routeTree) - - if !initial { - // Remove first, if the system removes a wanted route hopefully it will be re-added next - err := t.removeRoutes(findRemovedRoutes(routes, *oldRoutes)) - if err != nil { - util.LogWithContextIfNeeded("Failed to remove routes", err, t.l) - } - - // Ensure any routes we actually want are installed - err = t.addRoutes(true) - if err != nil { - // Catch any stray logs - util.LogWithContextIfNeeded("Failed to add routes", err, t.l) - } - } - - return nil -} - -func (t *winTun) Activate() error { - luid := winipcfg.LUID(t.tun.LUID()) - - err := luid.SetIPAddresses([]netip.Prefix{t.cidr}) - if err != nil { - return fmt.Errorf("failed to set address: %w", err) - } - - err = t.addRoutes(false) - if err != nil { - return err - } - - return nil -} - -func (t *winTun) addRoutes(logErrors bool) error { - luid := winipcfg.LUID(t.tun.LUID()) - routes := *t.Routes.Load() - foundDefault4 := false - - for _, r := range routes { - if !r.Via.IsValid() || !r.Install { - // We don't allow route MTUs so only install routes with a via - continue - } - - // Add our unsafe route - err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric)) - if err != nil { - retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err) - if logErrors { - retErr.Log(t.l) - continue - } else { - return retErr - } - } else { - t.l.WithField("route", r).Info("Added route") - } - - if !foundDefault4 { - if r.Cidr.Bits() == 0 && r.Cidr.Addr().BitLen() == 32 { - foundDefault4 = true - } - } - } - - ipif, err := luid.IPInterface(windows.AF_INET) - if err != nil { - return fmt.Errorf("failed to get ip interface: %w", err) - } - - ipif.NLMTU = uint32(t.MTU) - if foundDefault4 { - ipif.UseAutomaticMetric = false - ipif.Metric = 0 - } - - if err := ipif.Set(); err != nil { - return fmt.Errorf("failed to set ip interface: %w", err) - } - return nil -} - -func (t *winTun) removeRoutes(routes []Route) error { - luid := winipcfg.LUID(t.tun.LUID()) - - for _, r := range routes { - if !r.Install { - continue - } - - err := luid.DeleteRoute(r.Cidr, r.Via) - if err != nil { - t.l.WithError(err).WithField("route", r).Error("Failed to remove route") - } else { - t.l.WithField("route", r).Info("Removed route") - } - } - return nil -} - -func (t *winTun) RouteFor(ip netip.Addr) netip.Addr { - r, _ := t.routeTree.Load().Lookup(ip) - return r -} - -func (t *winTun) Cidr() netip.Prefix { - return t.cidr -} - -func (t *winTun) Name() string { - return t.Device -} - -func (t *winTun) Read(b []byte) (int, error) { - return t.tun.Read(b, 0) -} - -func (t *winTun) Write(b []byte) (int, error) { - return t.tun.Write(b, 0) -} - -func (t *winTun) NewMultiQueueReader() (io.ReadWriteCloser, error) { - return nil, fmt.Errorf("TODO: multiqueue not implemented for windows") -} - -func (t *winTun) Close() error { - // It seems that the Windows networking stack doesn't like it when we destroy interfaces that have active routes, - // so to be certain, just remove everything before destroying. - luid := winipcfg.LUID(t.tun.LUID()) - _ = luid.FlushRoutes(windows.AF_INET) - _ = luid.FlushIPAddresses(windows.AF_INET) - /* We don't support IPV6 yet - _ = luid.FlushRoutes(windows.AF_INET6) - _ = luid.FlushIPAddresses(windows.AF_INET6) - */ - _ = luid.FlushDNS(windows.AF_INET) - - return t.tun.Close() -} diff --git a/overlay/user.go b/overlay/user.go index 1bb4ef5f7..ae665f3a7 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -8,16 +8,16 @@ import ( "github.com/slackhq/nebula/config" ) -func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return NewUserDevice(tunCidr) +func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) { + return NewUserDevice(vpnNetworks) } -func NewUserDevice(tunCidr netip.Prefix) (Device, error) { +func NewUserDevice(vpnNetworks []netip.Prefix) (Device, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() return &UserDevice{ - tunCidr: tunCidr, + vpnNetworks: vpnNetworks, outboundReader: or, outboundWriter: ow, inboundReader: ir, @@ -26,7 +26,7 @@ func NewUserDevice(tunCidr netip.Prefix) (Device, error) { } type UserDevice struct { - tunCidr netip.Prefix + vpnNetworks []netip.Prefix outboundReader *io.PipeReader outboundWriter *io.PipeWriter @@ -38,7 +38,7 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks } func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { diff --git a/pki.go b/pki.go index fe64ea5ee..c7c93ac5b 100644 --- a/pki.go +++ b/pki.go @@ -1,13 +1,19 @@ package nebula import ( + "encoding/binary" + "encoding/json" "errors" "fmt" + "net" + "net/netip" "os" + "slices" "strings" "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" @@ -21,12 +27,22 @@ type PKI struct { } type CertState struct { - Certificate cert.Certificate - RawCertificate []byte - RawCertificateNoKey []byte - PublicKey []byte - PrivateKey []byte - pkcs11Backed bool + v1Cert cert.Certificate + v1HandshakeBytes []byte + + v2Cert cert.Certificate + v2HandshakeBytes []byte + + defaultVersion cert.Version + privateKey []byte + pkcs11Backed bool + cipher string + + myVpnNetworks []netip.Prefix + myVpnNetworksTable *bart.Table[struct{}] + myVpnAddrs []netip.Addr + myVpnAddrsTable *bart.Table[struct{}] + myVpnBroadcastAddrsTable *bart.Table[struct{}] } func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { @@ -46,16 +62,16 @@ func NewPKIFromConfig(l *logrus.Logger, c *config.C) (*PKI, error) { return pki, nil } -func (p *PKI) GetCertState() *CertState { - return p.cs.Load() -} - func (p *PKI) GetCAPool() *cert.CAPool { return p.caPool.Load() } +func (p *PKI) getCertState() *CertState { + return p.cs.Load() +} + func (p *PKI) reload(c *config.C, initial bool) error { - err := p.reloadCert(c, initial) + err := p.reloadCerts(c, initial) if err != nil { if initial { return err @@ -74,33 +90,94 @@ func (p *PKI) reload(c *config.C, initial bool) error { return nil } -func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { - cs, err := newCertStateFromConfig(c) +func (p *PKI) reloadCerts(c *config.C, initial bool) *util.ContextualError { + newState, err := newCertStateFromConfig(c) if err != nil { return util.NewContextualError("Could not load client cert", nil, err) } if !initial { - //TODO: include check for mask equality as well + currentState := p.cs.Load() + if newState.v1Cert != nil { + if currentState.v1Cert == nil { + return util.NewContextualError("v1 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v1Cert.Networks(), newState.v1Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v1Cert.Networks(), "old_networks": currentState.v1Cert.Networks()}, + nil, + ) + } + + if currentState.v1Cert.Curve() != newState.v1Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v1Cert.Curve(), "old_curve": currentState.v1Cert.Curve()}, + nil, + ) + } + + } else if currentState.v1Cert != nil { + //TODO: we should be able to tear this down + return util.NewContextualError("v1 certificate was removed, restart required", nil, err) + } + + if newState.v2Cert != nil { + if currentState.v2Cert == nil { + return util.NewContextualError("v2 certificate was added, restart required", nil, err) + } + + // did IP in cert change? if so, don't set + if !slices.Equal(currentState.v2Cert.Networks(), newState.v2Cert.Networks()) { + return util.NewContextualError( + "Networks in new cert was different from old", + m{"new_networks": newState.v2Cert.Networks(), "old_networks": currentState.v2Cert.Networks()}, + nil, + ) + } - // did IP in cert change? if so, don't set - currentCert := p.cs.Load().Certificate - oldIPs := currentCert.Networks() - newIPs := cs.Certificate.Networks() - if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { + if currentState.v2Cert.Curve() != newState.v2Cert.Curve() { + return util.NewContextualError( + "Curve in new cert was different from old", + m{"new_curve": newState.v2Cert.Curve(), "old_curve": currentState.v2Cert.Curve()}, + nil, + ) + } + + } else if currentState.v2Cert != nil { + return util.NewContextualError("v2 certificate was removed, restart required", nil, err) + } + + // Cipher cant be hot swapped so just leave it at what it was before + newState.cipher = currentState.cipher + + } else { + newState.cipher = c.GetString("cipher", "aes") + //TODO: this sucks and we should make it not a global + switch newState.cipher { + case "aes": + noiseEndianness = binary.BigEndian + case "chachapoly": + noiseEndianness = binary.LittleEndian + default: return util.NewContextualError( - "Networks in new cert was different from old", - m{"new_network": newIPs[0], "old_network": oldIPs[0]}, + "unknown cipher", + m{"cipher": newState.cipher}, nil, ) } } - p.cs.Store(cs) + p.cs.Store(newState) + + //TODO: newState needs a stringer that does json if initial { - p.l.WithField("cert", cs.Certificate).Debug("Client nebula certificate") + p.l.WithField("cert", newState).Debug("Client nebula certificate(s)") } else { - p.l.WithField("cert", cs.Certificate).Info("Client cert refreshed from disk") + p.l.WithField("cert", newState).Info("Client certificate(s) refreshed from disk") } return nil } @@ -116,55 +193,65 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { return nil } -func newCertState(certificate cert.Certificate, pkcs11backed bool, privateKey []byte) (*CertState, error) { - // Marshal the certificate to ensure it is valid - rawCertificate, err := certificate.Marshal() - if err != nil { - return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) +func (cs *CertState) GetDefaultCertificate() cert.Certificate { + c := cs.getCertificate(cs.defaultVersion) + if c == nil { + panic("No default certificate found") } + return c +} - publicKey := certificate.PublicKey() - cs := &CertState{ - RawCertificate: rawCertificate, - Certificate: certificate, - PrivateKey: privateKey, - PublicKey: publicKey, - pkcs11Backed: pkcs11backed, +func (cs *CertState) getCertificate(v cert.Version) cert.Certificate { + switch v { + case cert.Version1: + return cs.v1Cert + case cert.Version2: + return cs.v2Cert } - rawCertNoKey, err := cs.Certificate.MarshalForHandshakes() - if err != nil { - return nil, fmt.Errorf("error marshalling certificate no key: %s", err) + return nil +} + +// getHandshakeBytes returns the cached bytes to be used in a handshake message for the requested version. +// Callers must check if the return []byte is nil. +func (cs *CertState) getHandshakeBytes(v cert.Version) []byte { + switch v { + case cert.Version1: + return cs.v1HandshakeBytes + case cert.Version2: + return cs.v2HandshakeBytes + default: + return nil } - cs.RawCertificateNoKey = rawCertNoKey +} - return cs, nil +func (cs *CertState) String() string { + b, err := cs.MarshalJSON() + if err != nil { + return fmt.Sprintf("error marshaling certificate state: %v", err) + } + return string(b) } -func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { - var pemPrivateKey []byte - if strings.Contains(privPathOrPEM, "-----BEGIN") { - pemPrivateKey = []byte(privPathOrPEM) - privPathOrPEM = "" - rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) +func (cs *CertState) MarshalJSON() ([]byte, error) { + msg := []json.RawMessage{} + if cs.v1Cert != nil { + b, err := cs.v1Cert.MarshalJSON() if err != nil { - return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + return nil, err } - } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { - rawKey = []byte(privPathOrPEM) - return rawKey, cert.Curve_P256, true, nil - } else { - pemPrivateKey, err = os.ReadFile(privPathOrPEM) + msg = append(msg, b) + } + + if cs.v2Cert != nil { + b, err := cs.v2Cert.MarshalJSON() if err != nil { - return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) - } - rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) - if err != nil { - return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + return nil, err } + msg = append(msg, b) } - return + return json.Marshal(msg) } func newCertStateFromConfig(c *config.C) (*CertState, error) { @@ -198,24 +285,197 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { } } - nebulaCert, _, err := cert.UnmarshalCertificateFromPEM(rawCert) + var crt, v1, v2 cert.Certificate + for { + // Load the certificate + crt, rawCert, err = loadCertificate(rawCert) + if err != nil { + return nil, err + } + + switch crt.Version() { + case cert.Version1: + if v1 != nil { + return nil, fmt.Errorf("v1 certificate already found in pki.cert") + } + v1 = crt + case cert.Version2: + if v2 != nil { + return nil, fmt.Errorf("v2 certificate already found in pki.cert") + } + v2 = crt + default: + return nil, fmt.Errorf("unknown certificate version %v", crt.Version()) + } + + if len(rawCert) == 0 || strings.TrimSpace(string(rawCert)) == "" { + break + } + } + + if v1 == nil && v2 == nil { + return nil, errors.New("no certificates found in pki.cert") + } + + useDefaultVersion := uint32(1) + if v1 == nil { + // The only condition that requires v2 as the default is if only a v2 certificate is present + // We do this to avoid having to configure it specifically in the config file + useDefaultVersion = 2 + } + + rawDefaultVersion := c.GetUint32("pki.default_version", useDefaultVersion) + var defaultVersion cert.Version + switch rawDefaultVersion { + case 1: + if v1 == nil { + return nil, fmt.Errorf("can not use pki.default_version 1 without a v1 certificate in pki.cert") + } + defaultVersion = cert.Version1 + case 2: + defaultVersion = cert.Version2 + default: + return nil, fmt.Errorf("unknown pki.default_version: %v", rawDefaultVersion) + } + + return newCertState(defaultVersion, v1, v2, isPkcs11, curve, rawKey) +} + +func newCertState(dv cert.Version, v1, v2 cert.Certificate, pkcs11backed bool, privateKeyCurve cert.Curve, privateKey []byte) (*CertState, error) { + cs := CertState{ + privateKey: privateKey, + pkcs11Backed: pkcs11backed, + myVpnNetworksTable: new(bart.Table[struct{}]), + myVpnAddrsTable: new(bart.Table[struct{}]), + myVpnBroadcastAddrsTable: new(bart.Table[struct{}]), + } + + if v1 != nil && v2 != nil { + if !slices.Equal(v1.PublicKey(), v2.PublicKey()) { + return nil, util.NewContextualError("v1 and v2 public keys are not the same, ignoring", nil, nil) + } + + if v1.Curve() != v2.Curve() { + return nil, util.NewContextualError("v1 and v2 curve are not the same, ignoring", nil, nil) + } + + //TODO: make sure v2 has v1s address + + cs.defaultVersion = dv + } + + if v1 != nil { + if pkcs11backed { + //TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v1.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v1hs, err := v1.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v1Cert = v1 + cs.v1HandshakeBytes = v1hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version1 + } + } + + if v2 != nil { + if pkcs11backed { + //TODO: We do not currently have a method to verify a public private key pair when the private key is in an hsm + } else { + if err := v2.VerifyPrivateKey(privateKeyCurve, privateKey); err != nil { + return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + } + } + + v2hs, err := v2.MarshalForHandshakes() + if err != nil { + return nil, fmt.Errorf("error marshalling certificate for handshake: %w", err) + } + cs.v2Cert = v2 + cs.v2HandshakeBytes = v2hs + + if cs.defaultVersion == 0 { + cs.defaultVersion = cert.Version2 + } + } + + var crt cert.Certificate + crt = cs.getCertificate(cert.Version2) + if crt == nil { + // v2 certificates are a superset, only look at v1 if its all we have + crt = cs.getCertificate(cert.Version1) + } + + for _, network := range crt.Networks() { + cs.myVpnNetworks = append(cs.myVpnNetworks, network) + cs.myVpnNetworksTable.Insert(network, struct{}{}) + + cs.myVpnAddrs = append(cs.myVpnAddrs, network.Addr()) + cs.myVpnAddrsTable.Insert(netip.PrefixFrom(network.Addr(), network.Addr().BitLen()), struct{}{}) + + if network.Addr().Is4() { + addr := network.Masked().Addr().As4() + mask := net.CIDRMask(network.Bits(), network.Addr().BitLen()) + binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(mask)) + cs.myVpnBroadcastAddrsTable.Insert(netip.PrefixFrom(netip.AddrFrom4(addr), network.Addr().BitLen()), struct{}{}) + } + } + + return &cs, nil +} + +func loadPrivateKey(privPathOrPEM string) (rawKey []byte, curve cert.Curve, isPkcs11 bool, err error) { + var pemPrivateKey []byte + if strings.Contains(privPathOrPEM, "-----BEGIN") { + pemPrivateKey = []byte(privPathOrPEM) + privPathOrPEM = "" + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } else if strings.HasPrefix(privPathOrPEM, "pkcs11:") { + rawKey = []byte(privPathOrPEM) + return rawKey, cert.Curve_P256, true, nil + } else { + pemPrivateKey, err = os.ReadFile(privPathOrPEM) + if err != nil { + return nil, curve, false, fmt.Errorf("unable to read pki.key file %s: %s", privPathOrPEM, err) + } + rawKey, _, curve, err = cert.UnmarshalPrivateKeyFromPEM(pemPrivateKey) + if err != nil { + return nil, curve, false, fmt.Errorf("error while unmarshaling pki.key %s: %s", privPathOrPEM, err) + } + } + + return +} + +func loadCertificate(b []byte) (cert.Certificate, []byte, error) { + c, b, err := cert.UnmarshalCertificateFromPEM(b) if err != nil { - return nil, fmt.Errorf("error while unmarshaling pki.cert %s: %s", pubPathOrPEM, err) + return nil, b, fmt.Errorf("error while unmarshaling pki.cert: %w", err) } - if nebulaCert.Expired(time.Now()) { - return nil, fmt.Errorf("nebula certificate for this host is expired") + if c.Expired(time.Now()) { + return nil, b, fmt.Errorf("nebula certificate for this host is expired") } - if len(nebulaCert.Networks()) == 0 { - return nil, fmt.Errorf("no networks encoded in certificate") + if len(c.Networks()) == 0 { + return nil, b, fmt.Errorf("no networks encoded in certificate") } - if err = nebulaCert.VerifyPrivateKey(curve, rawKey); err != nil { - return nil, fmt.Errorf("private key is not a pair with public key in nebula cert") + if c.IsCA() { + return nil, b, fmt.Errorf("host certificate is a CA certificate") } - return newCertState(nebulaCert, isPkcs11, rawKey) + return c, b, nil } func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { diff --git a/relay_manager.go b/relay_manager.go index 1a3a4d48f..49e069af6 100644 --- a/relay_manager.go +++ b/relay_manager.go @@ -9,6 +9,7 @@ import ( "sync/atomic" "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/header" ) @@ -72,7 +73,7 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti Type: relayType, State: state, LocalIndex: index, - PeerIp: vpnIp, + PeerAddr: vpnIp, } if remoteIdx != nil { @@ -91,40 +92,59 @@ func AddRelay(l *logrus.Logger, relayHostInfo *HostInfo, hm *HostMap, vpnIp neti func (rm *relayManager) EstablishRelay(relayHostInfo *HostInfo, m *NebulaControl) (*Relay, error) { relay, ok := relayHostInfo.relayState.CompleteRelayByIdx(m.InitiatorRelayIndex, m.ResponderRelayIndex) if !ok { - rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnIp, + //TODO: we need to handle possibly logging deprecated fields as well + rm.l.WithFields(logrus.Fields{"relay": relayHostInfo.vpnAddrs[0], "initiatorRelayIndex": m.InitiatorRelayIndex, - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp}).Info("relayManager failed to update relay") + "relayFrom": m.RelayFromAddr, + "relayTo": m.RelayToAddr}).Info("relayManager failed to update relay") return nil, fmt.Errorf("unknown relay") } return relay, nil } -func (rm *relayManager) HandleControlMsg(h *HostInfo, m *NebulaControl, f *Interface) { +func (rm *relayManager) HandleControlMsg(h *HostInfo, d []byte, f *Interface) { + msg := &NebulaControl{} + err := msg.Unmarshal(d) + if err != nil { + h.logger(f.l).WithError(err).Error("Failed to unmarshal control message") + return + } + + var v cert.Version + if msg.OldRelayFromAddr > 0 || msg.OldRelayToAddr > 0 { + v = cert.Version1 + + //TODO: yeah this is junk but maybe its less junky than the other options + b := [4]byte{} + binary.BigEndian.PutUint32(b[:], msg.OldRelayFromAddr) + msg.RelayFromAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + + binary.BigEndian.PutUint32(b[:], msg.OldRelayToAddr) + msg.RelayToAddr = netAddrToProtoAddr(netip.AddrFrom4(b)) + } else { + v = cert.Version2 + } - switch m.Type { + switch msg.Type { case NebulaControl_CreateRelayRequest: - rm.handleCreateRelayRequest(h, f, m) + rm.handleCreateRelayRequest(v, h, f, msg) case NebulaControl_CreateRelayResponse: - rm.handleCreateRelayResponse(h, f, m) + rm.handleCreateRelayResponse(v, h, f, msg) } - } -func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m *NebulaControl) { +func (rm *relayManager) handleCreateRelayResponse(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { rm.l.WithFields(logrus.Fields{ - "relayFrom": m.RelayFromIp, - "relayTo": m.RelayToIp, + "relayFrom": protoAddrToNetAddr(m.RelayFromAddr), + "relayTo": protoAddrToNetAddr(m.RelayToAddr), "initiatorRelayIndex": m.InitiatorRelayIndex, "responderRelayIndex": m.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("handleCreateRelayResponse") - target := m.RelayToIp - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - targetAddr := netip.AddrFrom4(b) + + target := m.RelayToAddr + targetAddr := protoAddrToNetAddr(target) relay, err := rm.EstablishRelay(h, m) if err != nil { @@ -136,68 +156,83 @@ func (rm *relayManager) handleCreateRelayResponse(h *HostInfo, f *Interface, m * return } // I'm the middle man. Let the initiator know that the I've established the relay they requested. - peerHostInfo := rm.hostmap.QueryVpnIp(relay.PeerIp) + peerHostInfo := rm.hostmap.QueryVpnAddr(relay.PeerAddr) if peerHostInfo == nil { - rm.l.WithField("relayTo", relay.PeerIp).Error("Can't find a HostInfo for peer") + rm.l.WithField("relayTo", relay.PeerAddr).Error("Can't find a HostInfo for peer") return } peerRelay, ok := peerHostInfo.relayState.QueryRelayForByIp(targetAddr) if !ok { - rm.l.WithField("relayTo", peerHostInfo.vpnIp).Error("peerRelay does not have Relay state for relayTo") + rm.l.WithField("relayTo", peerHostInfo.vpnAddrs[0]).Error("peerRelay does not have Relay state for relayTo") return } - if peerRelay.State == PeerRequested { - //TODO: IPV6-WORK - b = peerHostInfo.vpnIp.As4() - peerRelay.State = Established + switch peerRelay.State { + case Requested: + // I initiated the request to this peer, but haven't heard back from the peer yet. I must wait for this peer + // to respond to complete the connection. + case PeerRequested, Disestablished, Established: + peerHostInfo.relayState.UpdateRelayForByIpState(targetAddr, Established) resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: peerRelay.LocalIndex, InitiatorRelayIndex: peerRelay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(b[:]), - RelayToIp: uint32(target), } + + if v == cert.Version1 { + peer := peerHostInfo.vpnAddrs[0] + if !peer.Is4() { + //TODO: log cant do it + return + } + + b := peer.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = targetAddr.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(peerHostInfo.vpnAddrs[0]) + resp.RelayToAddr = target + } + msg, err := resp.Marshal() if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") + rm.l.WithError(err). + Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") } else { f.SendMessageToHostInfo(header.Control, 0, peerHostInfo, msg, make([]byte, 12), make([]byte, mtu)) rm.l.WithFields(logrus.Fields{ - "relayFrom": resp.RelayFromIp, - "relayTo": resp.RelayToIp, + "relayFrom": resp.RelayFromAddr, + "relayTo": resp.RelayToAddr, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": peerHostInfo.vpnIp}). + "vpnAddrs": peerHostInfo.vpnAddrs}). Info("send CreateRelayResponse") } } } -func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *NebulaControl) { - //TODO: IPV6-WORK - b := [4]byte{} - binary.BigEndian.PutUint32(b[:], m.RelayFromIp) - from := netip.AddrFrom4(b) - - binary.BigEndian.PutUint32(b[:], m.RelayToIp) - target := netip.AddrFrom4(b) +func (rm *relayManager) handleCreateRelayRequest(v cert.Version, h *HostInfo, f *Interface, m *NebulaControl) { + from := protoAddrToNetAddr(m.RelayFromAddr) + target := protoAddrToNetAddr(m.RelayToAddr) logMsg := rm.l.WithFields(logrus.Fields{ "relayFrom": from, "relayTo": target, "initiatorRelayIndex": m.InitiatorRelayIndex, - "vpnIp": h.vpnIp}) + "vpnAddrs": h.vpnAddrs}) logMsg.Info("handleCreateRelayRequest") // Is the source of the relay me? This should never happen, but did happen due to // an issue migrating relays over to newly re-handshaked host info objects. - if from == f.myVpnNet.Addr() { + _, found := f.myVpnAddrsTable.Lookup(from) + if found { logMsg.WithField("myIP", from).Error("Discarding relay request from myself") return } + // Is the target of the relay me? - if target == f.myVpnNet.Addr() { + _, found = f.myVpnAddrsTable.Lookup(target) + if found { existingRelay, ok := h.relayState.QueryRelayForByIp(from) if ok { switch existingRelay.State { @@ -215,6 +250,21 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") return } + case Disestablished: + if existingRelay.RemoteIndex != m.InitiatorRelayIndex { + // We got a brand new Relay request, because its index is different than what we saw before. + // This should never happen. The peer should never change an index, once created. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") + return + } + // Mark the relay as 'Established' because it's safe to use again + h.relayState.UpdateRelayForByIpState(from, Established) + case PeerRequested: + // I should never be in this state, because I am terminal, not forwarding. + logMsg.WithFields(logrus.Fields{ + "existingRemoteIndex": existingRelay.RemoteIndex, + "state": existingRelay.State}).Error("Unexpected Relay State found") } } else { _, err := AddRelay(rm.l, h, f.hostMap, from, &m.InitiatorRelayIndex, TerminalType, Established) @@ -226,21 +276,26 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N relay, ok := h.relayState.QueryRelayForByIp(from) if !ok { - logMsg.Error("Relay State not found") + logMsg.WithField("from", from).Error("Relay State not found") return } - //TODO: IPV6-WORK - fromB := from.As4() - targetB := target.As4() - resp := NebulaControl{ Type: NebulaControl_CreateRelayResponse, ResponderRelayIndex: relay.LocalIndex, InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), } + + if v == cert.Version1 { + b := from.As4() + resp.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + resp.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + resp.RelayFromAddr = netAddrToProtoAddr(from) + resp.RelayToAddr = netAddrToProtoAddr(target) + } + msg, err := resp.Marshal() if err != nil { logMsg. @@ -253,7 +308,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N "relayTo": target, "initiatorRelayIndex": resp.InitiatorRelayIndex, "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). + "vpnAddrs": h.vpnAddrs}). Info("send CreateRelayResponse") } return @@ -262,7 +317,7 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N if !rm.GetAmRelay() { return } - peer := rm.hostmap.QueryVpnIp(target) + peer := rm.hostmap.QueryVpnAddr(target) if peer == nil { // Try to establish a connection to this host. If we get a future relay request, // we'll be ready! @@ -273,104 +328,65 @@ func (rm *relayManager) handleCreateRelayRequest(h *HostInfo, f *Interface, m *N // Only create relays to peers for whom I have a direct connection return } - sendCreateRequest := false var index uint32 var err error targetRelay, ok := peer.relayState.QueryRelayForByIp(from) if ok { index = targetRelay.LocalIndex - if targetRelay.State == Requested { - sendCreateRequest = true - } } else { // Allocate an index in the hostMap for this relay peer index, err = AddRelay(rm.l, peer, f.hostMap, from, nil, ForwardingType, Requested) if err != nil { return } - sendCreateRequest = true } - if sendCreateRequest { - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() - - // Send a CreateRelayRequest to the peer. - req := NebulaControl{ - Type: NebulaControl_CreateRelayRequest, - InitiatorRelayIndex: index, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), - } - msg, err := req.Marshal() - if err != nil { - logMsg. - WithError(err).Error("relayManager Failed to marshal Control message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK another lazy used to use the req object - "relayFrom": h.vpnIp, - "relayTo": target, - "initiatorRelayIndex": req.InitiatorRelayIndex, - "responderRelayIndex": req.ResponderRelayIndex, - "vpnIp": target}). - Info("send CreateRelayRequest") + peer.relayState.UpdateRelayForByIpState(from, Requested) + // Send a CreateRelayRequest to the peer. + req := NebulaControl{ + Type: NebulaControl_CreateRelayRequest, + InitiatorRelayIndex: index, + } + + if v == cert.Version1 { + if !h.vpnAddrs[0].Is4() { + //TODO: log it + return } + + b := h.vpnAddrs[0].As4() + req.OldRelayFromAddr = binary.BigEndian.Uint32(b[:]) + b = target.As4() + req.OldRelayToAddr = binary.BigEndian.Uint32(b[:]) + } else { + req.RelayFromAddr = netAddrToProtoAddr(h.vpnAddrs[0]) + req.RelayToAddr = netAddrToProtoAddr(target) } + + msg, err := req.Marshal() + if err != nil { + logMsg. + WithError(err).Error("relayManager Failed to marshal Control message to create relay") + } else { + f.SendMessageToHostInfo(header.Control, 0, peer, msg, make([]byte, 12), make([]byte, mtu)) + rm.l.WithFields(logrus.Fields{ + //TODO: IPV6-WORK another lazy used to use the req object + "relayFrom": h.vpnAddrs[0], + "relayTo": target, + "initiatorRelayIndex": req.InitiatorRelayIndex, + "responderRelayIndex": req.ResponderRelayIndex, + "vpnAddr": target}). + Info("send CreateRelayRequest") + } + // Also track the half-created Relay state just received - relay, ok := h.relayState.QueryRelayForByIp(target) + _, ok = h.relayState.QueryRelayForByIp(target) if !ok { - // Add the relay - state := PeerRequested - if targetRelay != nil && targetRelay.State == Established { - state = Established - } - _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, state) + _, err := AddRelay(rm.l, h, f.hostMap, target, &m.InitiatorRelayIndex, ForwardingType, PeerRequested) if err != nil { logMsg. WithError(err).Error("relayManager Failed to allocate a local index for relay") return } - } else { - switch relay.State { - case Established: - if relay.RemoteIndex != m.InitiatorRelayIndex { - // We got a brand new Relay request, because its index is different than what we saw before. - // This should never happen. The peer should never change an index, once created. - logMsg.WithFields(logrus.Fields{ - "existingRemoteIndex": relay.RemoteIndex}).Error("Existing relay mismatch with CreateRelayRequest") - return - } - //TODO: IPV6-WORK - fromB := h.vpnIp.As4() - targetB := target.As4() - resp := NebulaControl{ - Type: NebulaControl_CreateRelayResponse, - ResponderRelayIndex: relay.LocalIndex, - InitiatorRelayIndex: relay.RemoteIndex, - RelayFromIp: binary.BigEndian.Uint32(fromB[:]), - RelayToIp: binary.BigEndian.Uint32(targetB[:]), - } - msg, err := resp.Marshal() - if err != nil { - rm.l. - WithError(err).Error("relayManager Failed to marshal Control CreateRelayResponse message to create relay") - } else { - f.SendMessageToHostInfo(header.Control, 0, h, msg, make([]byte, 12), make([]byte, mtu)) - rm.l.WithFields(logrus.Fields{ - //TODO: IPV6-WORK more lazy, used to use resp object - "relayFrom": h.vpnIp, - "relayTo": target, - "initiatorRelayIndex": resp.InitiatorRelayIndex, - "responderRelayIndex": resp.ResponderRelayIndex, - "vpnIp": h.vpnIp}). - Info("send CreateRelayResponse") - } - - case Requested: - // Keep waiting for the other relay to complete - } } } } diff --git a/remote_list.go b/remote_list.go index 94db8f2f8..f59dbf2df 100644 --- a/remote_list.go +++ b/remote_list.go @@ -4,6 +4,7 @@ import ( "context" "net" "net/netip" + "slices" "sort" "strconv" "sync" @@ -17,8 +18,8 @@ import ( type forEachFunc func(addr netip.AddrPort, preferred bool) // The checkFuncs here are to simplify bulk importing LH query response logic into a single function (reset slice and iterate) -type checkFuncV4 func(vpnIp netip.Addr, to *Ip4AndPort) bool -type checkFuncV6 func(vpnIp netip.Addr, to *Ip6AndPort) bool +type checkFuncV4 func(vpnIp netip.Addr, to *V4AddrPort) bool +type checkFuncV6 func(vpnIp netip.Addr, to *V6AddrPort) bool // CacheMap is a struct that better represents the lighthouse cache for humans // The string key is the owners vpnIp @@ -48,14 +49,14 @@ type cacheRelay struct { // cacheV4 stores learned and reported ipv4 records under cache type cacheV4 struct { - learned *Ip4AndPort - reported []*Ip4AndPort + learned *V4AddrPort + reported []*V4AddrPort } // cacheV4 stores learned and reported ipv6 records under cache type cacheV6 struct { - learned *Ip6AndPort - reported []*Ip6AndPort + learned *V6AddrPort + reported []*V6AddrPort } type hostnamePort struct { @@ -170,7 +171,7 @@ func (hr *hostnamesResults) Cancel() { } } -func (hr *hostnamesResults) GetIPs() []netip.AddrPort { +func (hr *hostnamesResults) GetAddrs() []netip.AddrPort { var retSlice []netip.AddrPort if hr != nil { p := hr.ips.Load() @@ -189,6 +190,9 @@ type RemoteList struct { // Every interaction with internals requires a lock! sync.RWMutex + // The full list of vpn addresses assigned to this host + vpnAddrs []netip.Addr + // A deduplicated set of addresses. Any accessor should lock beforehand. addrs []netip.AddrPort @@ -212,13 +216,16 @@ type RemoteList struct { } // NewRemoteList creates a new empty RemoteList -func NewRemoteList(shouldAdd func(netip.Addr) bool) *RemoteList { - return &RemoteList{ +func NewRemoteList(vpnAddrs []netip.Addr, shouldAdd func(netip.Addr) bool) *RemoteList { + r := &RemoteList{ + vpnAddrs: make([]netip.Addr, len(vpnAddrs)), addrs: make([]netip.AddrPort, 0), relays: make([]netip.Addr, 0), cache: make(map[netip.Addr]*cache), shouldAdd: shouldAdd, } + copy(r.vpnAddrs, vpnAddrs) + return r } func (r *RemoteList) unlockedSetHostnamesResults(hr *hostnamesResults) { @@ -273,9 +280,9 @@ func (r *RemoteList) LearnRemote(ownerVpnIp netip.Addr, remote netip.AddrPort) { r.Lock() defer r.Unlock() if remote.Addr().Is4() { - r.unlockedSetLearnedV4(ownerVpnIp, NewIp4AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV4(ownerVpnIp, netAddrToProtoV4AddrPort(remote.Addr(), remote.Port())) } else { - r.unlockedSetLearnedV6(ownerVpnIp, NewIp6AndPortFromNetIP(remote.Addr(), remote.Port())) + r.unlockedSetLearnedV6(ownerVpnIp, netAddrToProtoV6AddrPort(remote.Addr(), remote.Port())) } } @@ -304,21 +311,21 @@ func (r *RemoteList) CopyCache() *CacheMap { if mc.v4 != nil { if mc.v4.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp4AndPort(mc.v4.learned)) + c.Learned = append(c.Learned, protoV4AddrPortToNetAddrPort(mc.v4.learned)) } for _, a := range mc.v4.reported { - c.Reported = append(c.Reported, AddrPortFromIp4AndPort(a)) + c.Reported = append(c.Reported, protoV4AddrPortToNetAddrPort(a)) } } if mc.v6 != nil { if mc.v6.learned != nil { - c.Learned = append(c.Learned, AddrPortFromIp6AndPort(mc.v6.learned)) + c.Learned = append(c.Learned, protoV6AddrPortToNetAddrPort(mc.v6.learned)) } for _, a := range mc.v6.reported { - c.Reported = append(c.Reported, AddrPortFromIp6AndPort(a)) + c.Reported = append(c.Reported, protoV6AddrPortToNetAddrPort(a)) } } @@ -401,14 +408,14 @@ func (r *RemoteList) unlockedIsBad(remote netip.AddrPort) bool { // unlockedSetLearnedV4 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedSetLearnedV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV4(ownerVpnIp).learned = to } // unlockedSetV4 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPort, check checkFuncV4) { +func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*V4AddrPort, check checkFuncV4) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) @@ -423,7 +430,7 @@ func (r *RemoteList) unlockedSetV4(ownerVpnIp, vpnIp netip.Addr, to []*Ip4AndPor } } -func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.Addr) { +func (r *RemoteList) unlockedSetRelay(ownerVpnIp netip.Addr, to []netip.Addr) { r.shouldRebuild = true c := r.unlockedGetOrMakeRelay(ownerVpnIp) @@ -436,12 +443,12 @@ func (r *RemoteList) unlockedSetRelay(ownerVpnIp, vpnIp netip.Addr, to []netip.A // unlockedPrependV4 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { +func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *V4AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV4(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip4AndPort{to}, c.reported...) + c.reported = append([]*V4AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -449,14 +456,14 @@ func (r *RemoteList) unlockedPrependV4(ownerVpnIp netip.Addr, to *Ip4AndPort) { // unlockedSetLearnedV6 assumes you have the write lock and sets the current learned address for this owner and marks the // deduplicated address list as dirty -func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedSetLearnedV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true r.unlockedGetOrMakeV6(ownerVpnIp).learned = to } // unlockedSetV6 assumes you have the write lock and resets the reported list of ips for this owner to the list provided // and marks the deduplicated address list as dirty -func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPort, check checkFuncV6) { +func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*V6AddrPort, check checkFuncV6) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) @@ -473,12 +480,12 @@ func (r *RemoteList) unlockedSetV6(ownerVpnIp, vpnIp netip.Addr, to []*Ip6AndPor // unlockedPrependV6 assumes you have the write lock and prepends the address in the reported list for this owner // This is only useful for establishing static hosts -func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *Ip6AndPort) { +func (r *RemoteList) unlockedPrependV6(ownerVpnIp netip.Addr, to *V6AddrPort) { r.shouldRebuild = true c := r.unlockedGetOrMakeV6(ownerVpnIp) // We are doing the easy append because this is rarely called - c.reported = append([]*Ip6AndPort{to}, c.reported...) + c.reported = append([]*V6AddrPort{to}, c.reported...) if len(c.reported) > MaxRemotes { c.reported = c.reported[:MaxRemotes] } @@ -536,14 +543,14 @@ func (r *RemoteList) unlockedCollect() { for _, c := range r.cache { if c.v4 != nil { if c.v4.learned != nil { - u := AddrPortFromIp4AndPort(c.v4.learned) + u := protoV4AddrPortToNetAddrPort(c.v4.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v4.reported { - u := AddrPortFromIp4AndPort(v) + u := protoV4AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -552,14 +559,14 @@ func (r *RemoteList) unlockedCollect() { if c.v6 != nil { if c.v6.learned != nil { - u := AddrPortFromIp6AndPort(c.v6.learned) + u := protoV6AddrPortToNetAddrPort(c.v6.learned) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } } for _, v := range c.v6.reported { - u := AddrPortFromIp6AndPort(v) + u := protoV6AddrPortToNetAddrPort(v) if !r.unlockedIsBad(u) { addrs = append(addrs, u) } @@ -573,7 +580,7 @@ func (r *RemoteList) unlockedCollect() { } } - dnsAddrs := r.hr.GetIPs() + dnsAddrs := r.hr.GetAddrs() for _, addr := range dnsAddrs { if r.shouldAdd == nil || r.shouldAdd(addr.Addr()) { if !r.unlockedIsBad(addr) { @@ -589,6 +596,21 @@ func (r *RemoteList) unlockedCollect() { // unlockedSort assumes you have the write lock and performs the deduping and sorting of the address list func (r *RemoteList) unlockedSort(preferredRanges []netip.Prefix) { + // Use a map to deduplicate any relay addresses + dedupedRelays := map[netip.Addr]struct{}{} + for _, relay := range r.relays { + dedupedRelays[relay] = struct{}{} + } + r.relays = r.relays[:0] + for relay := range dedupedRelays { + r.relays = append(r.relays, relay) + } + // Put them in a somewhat consistent order after de-duplication + slices.SortFunc(r.relays, func(a, b netip.Addr) int { + return a.Compare(b) + }) + + // Now the addrs n := len(r.addrs) if n < 2 { return diff --git a/remote_list_test.go b/remote_list_test.go index 62a892b00..0caf86a4d 100644 --- a/remote_list_test.go +++ b/remote_list_test.go @@ -9,11 +9,11 @@ import ( ) func TestRemoteList_Rebuild(t *testing.T) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), // this is duped newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), // this is duped @@ -25,20 +25,30 @@ func TestRemoteList_Rebuild(t *testing.T) { newIp4AndPortFromString("70.199.182.92:1476"), // almost dupe of 0 with a diff port newIp4AndPortFromString("70.199.182.92:1475"), // this is a dupe }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.1"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), // this is duped newIp6AndPortFromString("[1::1]:2"), // almost dupe of 0 with a diff port, also gets duped newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe newIp6AndPortFromString("[1::1]:2"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, + ) + + rl.unlockedSetRelay( + netip.MustParseAddr("0.0.0.1"), + []netip.Addr{ + netip.MustParseAddr("1::1"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1.2.3.4"), + netip.MustParseAddr("1::1"), + }, ) rl.Rebuild([]netip.Prefix{}) @@ -76,6 +86,11 @@ func TestRemoteList_Rebuild(t *testing.T) { assert.Equal(t, "[1::1]:2", rl.addrs[8].String()) assert.Equal(t, "[1:100::1]:1", rl.addrs[9].String()) + // assert relay deduplicated + assert.Len(t, rl.relays, 2) + assert.Equal(t, "1.2.3.4", rl.relays[0].String()) + assert.Equal(t, "1::1", rl.relays[1].String()) + // Ensure we can hoist a specific ipv4 range over anything else rl.Rebuild([]netip.Prefix{netip.MustParsePrefix("172.17.0.0/16")}) assert.Len(t, rl.addrs, 10, "addrs contains too many entries") @@ -98,11 +113,11 @@ func TestRemoteList_Rebuild(t *testing.T) { } func BenchmarkFullRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -112,19 +127,19 @@ func BenchmarkFullRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -160,11 +175,11 @@ func BenchmarkFullRebuild(b *testing.B) { } func BenchmarkSortRebuild(b *testing.B) { - rl := NewRemoteList(nil) + rl := NewRemoteList([]netip.Addr{netip.MustParseAddr("0.0.0.0")}, nil) rl.unlockedSetV4( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip4AndPort{ + []*V4AddrPort{ newIp4AndPortFromString("70.199.182.92:1475"), newIp4AndPortFromString("172.17.0.182:10101"), newIp4AndPortFromString("172.17.1.1:10101"), @@ -174,19 +189,19 @@ func BenchmarkSortRebuild(b *testing.B) { newIp4AndPortFromString("172.17.1.1:10101"), // this is a dupe newIp4AndPortFromString("70.199.182.92:1476"), // dupe of 0 with a diff port }, - func(netip.Addr, *Ip4AndPort) bool { return true }, + func(netip.Addr, *V4AddrPort) bool { return true }, ) rl.unlockedSetV6( netip.MustParseAddr("0.0.0.0"), netip.MustParseAddr("0.0.0.0"), - []*Ip6AndPort{ + []*V6AddrPort{ newIp6AndPortFromString("[1::1]:1"), newIp6AndPortFromString("[1::1]:2"), // dupe of 0 with a diff port newIp6AndPortFromString("[1:100::1]:1"), newIp6AndPortFromString("[1::1]:1"), // this is a dupe }, - func(netip.Addr, *Ip6AndPort) bool { return true }, + func(netip.Addr, *V6AddrPort) bool { return true }, ) b.Run("no preferred", func(b *testing.B) { @@ -224,19 +239,19 @@ func BenchmarkSortRebuild(b *testing.B) { }) } -func newIp4AndPortFromString(s string) *Ip4AndPort { +func newIp4AndPortFromString(s string) *V4AddrPort { a := netip.MustParseAddrPort(s) v4Addr := a.Addr().As4() - return &Ip4AndPort{ - Ip: binary.BigEndian.Uint32(v4Addr[:]), + return &V4AddrPort{ + Addr: binary.BigEndian.Uint32(v4Addr[:]), Port: uint32(a.Port()), } } -func newIp6AndPortFromString(s string) *Ip6AndPort { +func newIp6AndPortFromString(s string) *V6AddrPort { a := netip.MustParseAddrPort(s) v6Addr := a.Addr().As16() - return &Ip6AndPort{ + return &V6AddrPort{ Hi: binary.BigEndian.Uint64(v6Addr[:8]), Lo: binary.BigEndian.Uint64(v6Addr[8:]), Port: uint32(a.Port()), diff --git a/service/service.go b/service/service.go index 4ddd30182..4339677bd 100644 --- a/service/service.go +++ b/service/service.go @@ -90,9 +90,9 @@ func New(config *config.C) (*Service, error) { }, }) - ipNet := device.Cidr() + ipNet := device.Networks() pa := tcpip.ProtocolAddress{ - AddressWithPrefix: tcpip.AddrFromSlice(ipNet.Addr().AsSlice()).WithPrefix(), + AddressWithPrefix: tcpip.AddrFromSlice(ipNet[0].Addr().AsSlice()).WithPrefix(), Protocol: ipv4.ProtocolNumber, } if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ diff --git a/service/service_test.go b/service/service_test.go index e9fceef6f..613758e19 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,8 +10,8 @@ import ( "dario.cat/mergo" "github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/cert_test" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" "gopkg.in/yaml.v2" ) @@ -19,7 +19,7 @@ import ( type m map[string]interface{} func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) + _, _, myPrivKey, myPEM := cert_test.NewTestCert(cert.Version2, cert.Curve_CURVE25519, caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{netip.PrefixFrom(udpIp, 24)}, nil, []string{}) caB, err := caCrt.MarshalPEM() if err != nil { panic(err) @@ -79,7 +79,7 @@ func newSimpleService(caCrt cert.Certificate, caKey []byte, name string, udpIp n } func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := cert_test.NewTestCaCert(cert.Version2, cert.Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ diff --git a/ssh.go b/ssh.go index 881ee4696..1062a77f9 100644 --- a/ssh.go +++ b/ssh.go @@ -320,7 +320,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-cert", - ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn ip", + ShortDescription: "Prints the current certificate being used or the certificate for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintCertFlags{} @@ -336,7 +336,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "print-tunnel", - ShortDescription: "Prints json details about a tunnel for the provided vpn ip", + ShortDescription: "Prints json details about a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshPrintTunnelFlags{} @@ -364,7 +364,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "change-remote", - ShortDescription: "Changes the remote address used in the tunnel for the provided vpn ip", + ShortDescription: "Changes the remote address used in the tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshChangeRemoteFlags{} @@ -378,7 +378,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "close-tunnel", - ShortDescription: "Closes a tunnel for the provided vpn ip", + ShortDescription: "Closes a tunnel for the provided vpn addr", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) s := sshCloseTunnelFlags{} @@ -392,7 +392,7 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "create-tunnel", - ShortDescription: "Creates a tunnel for the provided vpn ip and address", + ShortDescription: "Creates a tunnel for the provided vpn address", Help: "The lighthouses will be queried for real addresses but you can provide one as well.", Flags: func() (*flag.FlagSet, interface{}) { fl := flag.NewFlagSet("", flag.ContinueOnError) @@ -407,8 +407,8 @@ func attachCommands(l *logrus.Logger, c *config.C, ssh *sshd.SSHServer, f *Inter ssh.RegisterCommand(&sshd.Command{ Name: "query-lighthouse", - ShortDescription: "Query the lighthouses for the provided vpn ip", - Help: "This command is asynchronous. Only currently known udp ips will be printed.", + ShortDescription: "Query the lighthouses for the provided vpn address", + Help: "This command is asynchronous. Only currently known udp addresses will be printed.", Callback: func(fs interface{}, a []string, w sshd.StringWriter) error { return sshQueryLighthouse(f, fs, a, w) }, @@ -430,7 +430,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } sort.Slice(hm, func(i, j int) bool { - return hm[i].VpnIp.Compare(hm[j].VpnIp) < 0 + return hm[i].VpnAddrs[0].Compare(hm[j].VpnAddrs[0]) < 0 }) if fs.Json || fs.Pretty { @@ -447,7 +447,7 @@ func sshListHostMap(hl controlHostLister, a interface{}, w sshd.StringWriter) er } else { for _, v := range hm { - err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, v.RemoteAddrs)) + err := w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddrs, v.RemoteAddrs)) if err != nil { return err } @@ -465,8 +465,8 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr } type lighthouseInfo struct { - VpnIp string `json:"vpnIp"` - Addrs *CacheMap `json:"addrs"` + VpnAddr string `json:"vpnAddr"` + Addrs *CacheMap `json:"addrs"` } lightHouse.RLock() @@ -474,15 +474,15 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr x := 0 for k, v := range lightHouse.addrMap { addrMap[x] = lighthouseInfo{ - VpnIp: k.String(), - Addrs: v.CopyCache(), + VpnAddr: k.String(), + Addrs: v.CopyCache(), } x++ } lightHouse.RUnlock() sort.Slice(addrMap, func(i, j int) bool { - return strings.Compare(addrMap[i].VpnIp, addrMap[j].VpnIp) < 0 + return strings.Compare(addrMap[i].VpnAddr, addrMap[j].VpnAddr) < 0 }) if fs.Json || fs.Pretty { @@ -503,7 +503,7 @@ func sshListLighthouseMap(lightHouse *LightHouse, a interface{}, w sshd.StringWr if err != nil { return err } - err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnIp, string(b))) + err = w.WriteLine(fmt.Sprintf("%s: %s", v.VpnAddr, string(b))) if err != nil { return err } @@ -541,20 +541,20 @@ func sshVersion(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter func sshQueryLighthouse(ifce *Interface, fs interface{}, a []string, w sshd.StringWriter) error { if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } var cm *CacheMap - rl := ifce.lightHouse.Query(vpnIp) + rl := ifce.lightHouse.Query(vpnAddr) if rl != nil { cm = rl.CopyCache() } @@ -569,21 +569,21 @@ func sshCloseTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } if !flags.LocalOnly { @@ -610,24 +610,24 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already exists")) } - hostInfo = ifce.handshakeManager.QueryVpnIp(vpnIp) + hostInfo = ifce.handshakeManager.QueryVpnAddr(vpnAddr) if hostInfo != nil { return w.WriteLine(fmt.Sprintf("Tunnel already handshaking")) } @@ -640,7 +640,7 @@ func sshCreateTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringW } } - hostInfo = ifce.handshakeManager.StartHandshake(vpnIp, nil) + hostInfo = ifce.handshakeManager.StartHandshake(vpnAddr, nil) if addr.IsValid() { hostInfo.SetRemote(addr) } @@ -656,7 +656,7 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } if flags.Address == "" { @@ -668,18 +668,18 @@ func sshChangeRemote(ifce *Interface, fs interface{}, a []string, w sshd.StringW return w.WriteLine("Address could not be parsed") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn address could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn address: %v", a[0])) } hostInfo.SetRemote(addr) @@ -785,20 +785,20 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return nil } - cert := ifce.pki.GetCertState().Certificate + cert := ifce.pki.getCertState().GetDefaultCertificate() if len(a) > 0 { - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } cert = hostInfo.GetCert().Certificate @@ -856,15 +856,15 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr Error error Type string State string - PeerIp netip.Addr + PeerAddr netip.Addr LocalIndex uint32 RemoteIndex uint32 RelayedThrough []netip.Addr } type RelayOutput struct { - NebulaIp netip.Addr - RelayForIps []RelayFor + NebulaAddr netip.Addr + RelayForAddrs []RelayFor } type CmdOutput struct { @@ -880,16 +880,16 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } for k, v := range relays { - ro := RelayOutput{NebulaIp: v.vpnIp} + ro := RelayOutput{NebulaAddr: v.vpnAddrs[0]} co.Relays = append(co.Relays, &ro) - relayHI := ifce.hostMap.QueryVpnIp(v.vpnIp) + relayHI := ifce.hostMap.QueryVpnAddr(v.vpnAddrs[0]) if relayHI == nil { - ro.RelayForIps = append(ro.RelayForIps, RelayFor{Error: errors.New("could not find hostinfo")}) + ro.RelayForAddrs = append(ro.RelayForAddrs, RelayFor{Error: errors.New("could not find hostinfo")}) continue } - for _, vpnIp := range relayHI.relayState.CopyRelayForIps() { + for _, vpnAddr := range relayHI.relayState.CopyRelayForIps() { rf := RelayFor{Error: nil} - r, ok := relayHI.relayState.GetRelayForByIp(vpnIp) + r, ok := relayHI.relayState.GetRelayForByAddr(vpnAddr) if ok { t := "" switch r.Type { @@ -913,19 +913,19 @@ func sshPrintRelays(ifce *Interface, fs interface{}, a []string, w sshd.StringWr rf.LocalIndex = r.LocalIndex rf.RemoteIndex = r.RemoteIndex - rf.PeerIp = r.PeerIp + rf.PeerAddr = r.PeerAddr rf.Type = t rf.State = s if rf.LocalIndex != k { rf.Error = fmt.Errorf("hostmap LocalIndex '%v' does not match RelayState LocalIndex", k) } } - relayedHI := ifce.hostMap.QueryVpnIp(vpnIp) + relayedHI := ifce.hostMap.QueryVpnAddr(vpnAddr) if relayedHI != nil { rf.RelayedThrough = append(rf.RelayedThrough, relayedHI.relayState.CopyRelayIps()...) } - ro.RelayForIps = append(ro.RelayForIps, rf) + ro.RelayForAddrs = append(ro.RelayForAddrs, rf) } } err := enc.Encode(co) @@ -943,21 +943,21 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr } if len(a) == 0 { - return w.WriteLine("No vpn ip was provided") + return w.WriteLine("No vpn address was provided") } - vpnIp, err := netip.ParseAddr(a[0]) + vpnAddr, err := netip.ParseAddr(a[0]) if err != nil { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - if !vpnIp.IsValid() { - return w.WriteLine(fmt.Sprintf("The provided vpn ip could not be parsed: %s", a[0])) + if !vpnAddr.IsValid() { + return w.WriteLine(fmt.Sprintf("The provided vpn addr could not be parsed: %s", a[0])) } - hostInfo := ifce.hostMap.QueryVpnIp(vpnIp) + hostInfo := ifce.hostMap.QueryVpnAddr(vpnAddr) if hostInfo == nil { - return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) + return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn addr: %v", a[0])) } enc := json.NewEncoder(w.GetWriter()) @@ -971,13 +971,15 @@ func sshPrintTunnel(ifce *Interface, fs interface{}, a []string, w sshd.StringWr func sshDeviceInfo(ifce *Interface, fs interface{}, w sshd.StringWriter) error { data := struct { - Name string `json:"name"` - Cidr string `json:"cidr"` + Name string `json:"name"` + Cidr []netip.Prefix `json:"cidr"` }{ Name: ifce.inside.Name(), - Cidr: ifce.inside.Cidr().String(), + Cidr: make([]netip.Prefix, len(ifce.inside.Networks())), } + copy(data.Cidr, ifce.inside.Networks()) + flags, ok := fs.(*sshDeviceInfoFlags) if !ok { return fmt.Errorf("internal error: expected flags to be sshDeviceInfoFlags but was %+v", fs) diff --git a/test/tun.go b/test/tun.go index fbf58295a..b29d61c7b 100644 --- a/test/tun.go +++ b/test/tun.go @@ -16,8 +16,8 @@ func (NoopTun) Activate() error { return nil } -func (NoopTun) Cidr() netip.Prefix { - return netip.Prefix{} +func (NoopTun) Networks() []netip.Prefix { + return []netip.Prefix{} } func (NoopTun) Name() string { diff --git a/timeout_test.go b/timeout_test.go index 4c6364ef5..db36fec73 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -116,10 +116,10 @@ func TestTimerWheel_Purge(t *testing.T) { assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ - {LocalIP: netip.MustParseAddr("0.0.0.1")}, - {LocalIP: netip.MustParseAddr("0.0.0.2")}, - {LocalIP: netip.MustParseAddr("0.0.0.3")}, - {LocalIP: netip.MustParseAddr("0.0.0.4")}, + {LocalAddr: netip.MustParseAddr("0.0.0.1")}, + {LocalAddr: netip.MustParseAddr("0.0.0.2")}, + {LocalAddr: netip.MustParseAddr("0.0.0.3")}, + {LocalAddr: netip.MustParseAddr("0.0.0.4")}, } tw.Add(fps[0], time.Second*1) diff --git a/udp/conn.go b/udp/conn.go index fa4e44304..895b0df35 100644 --- a/udp/conn.go +++ b/udp/conn.go @@ -4,28 +4,19 @@ import ( "net/netip" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) const MTU = 9001 type EncReader func( addr netip.AddrPort, - out []byte, - packet []byte, - header *header.H, - fwPacket *firewall.Packet, - lhh LightHouseHandlerFunc, - nb []byte, - q int, - localCache firewall.ConntrackCache, + payload []byte, ) type Conn interface { Rebind() error LocalAddr() (netip.AddrPort, error) - ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) + ListenOut(r EncReader) WriteTo(b []byte, addr netip.AddrPort) error ReloadConfig(c *config.C) Close() error @@ -39,7 +30,7 @@ func (NoopConn) Rebind() error { func (NoopConn) LocalAddr() (netip.AddrPort, error) { return netip.AddrPort{}, nil } -func (NoopConn) ListenOut(_ EncReader, _ LightHouseHandlerFunc, _ *firewall.ConntrackCacheTicker, _ int) { +func (NoopConn) ListenOut(_ EncReader) { return } func (NoopConn) WriteTo(_ []byte, _ netip.AddrPort) error { diff --git a/udp/temp.go b/udp/temp.go deleted file mode 100644 index b281906f5..000000000 --- a/udp/temp.go +++ /dev/null @@ -1,10 +0,0 @@ -package udp - -import ( - "net/netip" -) - -//TODO: The items in this file belong in their own packages but doing that in a single PR is a nightmare - -// TODO: IPV6-WORK this can likely be removed now -type LightHouseHandlerFunc func(rAddr netip.AddrPort, vpnIp netip.Addr, p []byte) diff --git a/udp/udp_generic.go b/udp/udp_generic.go index 2d8453694..99a3eca21 100644 --- a/udp/udp_generic.go +++ b/udp/udp_generic.go @@ -15,8 +15,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" ) type GenericConn struct { @@ -72,12 +70,8 @@ type rawMessage struct { Len uint32 } -func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *GenericConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -87,16 +81,6 @@ func (u *GenericConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *f return } - r( - netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(rua.Addr().Unmap(), rua.Port()), buffer[:n]) } } diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2eee76ee2..36ab67c50 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -14,8 +14,6 @@ import ( "github.com/rcrowley/go-metrics" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" "golang.org/x/sys/unix" ) @@ -120,15 +118,9 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } } -func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} +func (u *StdConn) ListenOut(r EncReader) { var ip netip.Addr - nb := make([]byte, 12, 12) - //TODO: should we track this? - //metric := metrics.GetOrRegisterHistogram("test.batch_read", nil, metrics.NewExpDecaySample(1028, 0.015)) msgs, buffers, names := u.PrepareRawMessages(u.batch) read := u.ReadMulti if u.batch == 1 { @@ -142,26 +134,14 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - //metric.Update(int64(n)) for i := 0; i < n; i++ { + // Its ok to skip the ok check here, the slicing is the only error that can occur and it will panic if u.isV4 { ip, _ = netip.AddrFromSlice(names[i][4:8]) - //TODO: IPV6-WORK what is not ok? } else { ip, _ = netip.AddrFromSlice(names[i][8:24]) - //TODO: IPV6-WORK what is not ok? } - r( - netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), - plaintext[:0], - buffers[i][:msgs[i].Len], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(ip.Unmap(), binary.BigEndian.Uint16(names[i][2:4])), buffers[i][:msgs[i].Len]) } } } diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index ee7e1e002..585b642bb 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -18,9 +18,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" - "github.com/slackhq/nebula/header" - "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/conn/winrio" ) @@ -118,12 +115,8 @@ func (u *RIOConn) bind(sa windows.Sockaddr) error { return nil } -func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) +func (u *RIOConn) ListenOut(r EncReader) { buffer := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) for { // Just read one packet at a time @@ -133,17 +126,7 @@ func (u *RIOConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - r( - netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), - plaintext[:0], - buffer[:n], - h, - fwPacket, - lhf, - nb, - q, - cache.Get(u.l), - ) + r(netip.AddrPortFrom(netip.AddrFrom16(rua.Addr).Unmap(), (rua.Port>>8)|((rua.Port&0xff)<<8)), buffer[:n]) } } diff --git a/udp/udp_tester.go b/udp/udp_tester.go index f03a3535f..8d5e6c14a 100644 --- a/udp/udp_tester.go +++ b/udp/udp_tester.go @@ -10,7 +10,6 @@ import ( "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/header" ) @@ -107,18 +106,13 @@ func (u *TesterConn) WriteTo(b []byte, addr netip.AddrPort) error { return nil } -func (u *TesterConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { - plaintext := make([]byte, MTU) - h := &header.H{} - fwPacket := &firewall.Packet{} - nb := make([]byte, 12, 12) - +func (u *TesterConn) ListenOut(r EncReader) { for { p, ok := <-u.RxPackets if !ok { return } - r(p.From, plaintext[:0], p.Data, h, fwPacket, lhf, nb, q, cache.Get(u.l)) + r(p.From, p.Data) } }