diff --git a/firewall.go b/firewall.go index 06b8e8589..d1e306348 100644 --- a/firewall.go +++ b/firewall.go @@ -432,8 +432,8 @@ func (f *Firewall) Drop(fp firewall.Packet, incoming bool, h *HostInfo, caPool * } // Make sure remote address matches nebula certificate - if remoteCidr := h.remoteCidr; remoteCidr != nil { - _, ok := remoteCidr.Lookup(fp.RemoteAddr) + if h.networks != nil { + _, ok := h.networks.Lookup(fp.RemoteAddr) if !ok { f.metrics(incoming).droppedRemoteAddr.Inc(1) return ErrInvalidRemoteIP diff --git a/firewall_test.go b/firewall_test.go index 1bdfe6f93..c093a6e7b 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -152,7 +152,7 @@ func TestFirewall_Drop(t *testing.T) { }, vpnAddrs: []netip.Addr{netip.MustParseAddr("1.2.3.4")}, } - h.CreateRemoteCIDR(&c) + h.buildNetworks(&c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -332,7 +332,7 @@ func TestFirewall_Drop2(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(c.Certificate) + h.buildNetworks(c.Certificate) c1 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -346,7 +346,7 @@ func TestFirewall_Drop2(t *testing.T) { peerCert: &c1, }, } - h1.CreateRemoteCIDR(c1.Certificate) + h1.buildNetworks(c1.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -394,7 +394,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h1.CreateRemoteCIDR(c1.Certificate) + h1.buildNetworks(c1.Certificate) c2 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -409,7 +409,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h2.CreateRemoteCIDR(c2.Certificate) + h2.buildNetworks(c2.Certificate) c3 := cert.CachedCertificate{ Certificate: &dummyCert{ @@ -424,7 +424,7 @@ func TestFirewall_Drop3(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h3.CreateRemoteCIDR(c3.Certificate) + h3.buildNetworks(c3.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) @@ -471,7 +471,7 @@ func TestFirewall_DropConntrackReload(t *testing.T) { }, vpnAddrs: []netip.Addr{network.Addr()}, } - h.CreateRemoteCIDR(c.Certificate) + h.buildNetworks(c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) assert.Nil(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) diff --git a/handshake_ix.go b/handshake_ix.go index d5c39f95a..e47bdf08f 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -294,7 +294,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnAddrs) hostinfo.SetRemote(addr) - hostinfo.CreateRemoteCIDR(remoteCert.Certificate) + hostinfo.buildNetworks(remoteCert.Certificate) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -570,7 +570,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha } // Build up the radix for the firewall if we have subnets in the cert - hostinfo.CreateRemoteCIDR(remoteCert.Certificate) + hostinfo.buildNetworks(remoteCert.Certificate) // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) diff --git a/hostmap.go b/hostmap.go index dc16af208..3b562b2f1 100644 --- a/hostmap.go +++ b/hostmap.go @@ -191,8 +191,10 @@ type HostInfo struct { localIndexId uint32 vpnAddrs []netip.Addr recvError atomic.Uint32 - remoteCidr *bart.Table[struct{}] //TODO: rename `vpnNetworks` - relayState RelayState + + // networks are both all vpn and unsafe networks assigned to this host + networks *bart.Table[struct{}] + relayState RelayState // HandshakePacket records the packets used to create this hostinfo // We need these to avoid replayed handshake packets creating new hostinfos which causes churn @@ -652,21 +654,20 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) { +func (i *HostInfo) buildNetworks(c cert.Certificate) { if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed return } - remoteCidr := new(bart.Table[struct{}]) + i.networks = new(bart.Table[struct{}]) for _, network := range c.Networks() { - remoteCidr.Insert(network, struct{}{}) + i.networks.Insert(network, struct{}{}) } for _, network := range c.UnsafeNetworks() { - remoteCidr.Insert(network, struct{}{}) + i.networks.Insert(network, struct{}{}) } - i.remoteCidr = remoteCidr } func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry {