Skip to content

Commit

Permalink
Use generics for CIDRTrees to avoid casting issues (#1004)
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus authored Nov 2, 2023
1 parent a44e1b8 commit 5181cb0
Show file tree
Hide file tree
Showing 21 changed files with 264 additions and 247 deletions.
43 changes: 14 additions & 29 deletions allow_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ import (

type AllowList struct {
// The values of this cidrTree are `bool`, signifying allow/deny
cidrTree *cidr.Tree6
cidrTree *cidr.Tree6[bool]
}

type RemoteAllowList struct {
AllowList *AllowList

// Inside Range Specific, keys of this tree are inside CIDRs and values
// are *AllowList
insideAllowLists *cidr.Tree6
insideAllowLists *cidr.Tree6[*AllowList]
}

type LocalAllowList struct {
Expand Down Expand Up @@ -88,7 +88,7 @@ func newAllowList(k string, raw interface{}, handleKey func(key string, value in
return nil, fmt.Errorf("config `%s` has invalid type: %T", k, raw)
}

tree := cidr.NewTree6()
tree := cidr.NewTree6[bool]()

// Keep track of the rules we have added for both ipv4 and ipv6
type allowListRules struct {
Expand Down Expand Up @@ -218,13 +218,13 @@ func getAllowListInterfaces(k string, v interface{}) ([]AllowListNameRule, error
return nameRules, nil
}

func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6, error) {
func getRemoteAllowRanges(c *config.C, k string) (*cidr.Tree6[*AllowList], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}

remoteAllowRanges := cidr.NewTree6()
remoteAllowRanges := cidr.NewTree6[*AllowList]()

rawMap, ok := value.(map[interface{}]interface{})
if !ok {
Expand Down Expand Up @@ -257,41 +257,26 @@ func (al *AllowList) Allow(ip net.IP) bool {
return true
}

result := al.cidrTree.MostSpecificContains(ip)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
_, result := al.cidrTree.MostSpecificContains(ip)
return result
}

func (al *AllowList) AllowIpV4(ip iputil.VpnIp) bool {
if al == nil {
return true
}

result := al.cidrTree.MostSpecificContainsIpV4(ip)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
_, result := al.cidrTree.MostSpecificContainsIpV4(ip)
return result
}

func (al *AllowList) AllowIpV6(hi, lo uint64) bool {
if al == nil {
return true
}

result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
switch v := result.(type) {
case bool:
return v
default:
panic(fmt.Errorf("invalid state, allowlist returned: %T %v", result, result))
}
_, result := al.cidrTree.MostSpecificContainsIpV6(hi, lo)
return result
}

func (al *LocalAllowList) Allow(ip net.IP) bool {
Expand Down Expand Up @@ -352,9 +337,9 @@ func (al *RemoteAllowList) AllowIpV6(vpnIp iputil.VpnIp, hi, lo uint64) bool {

func (al *RemoteAllowList) getInsideAllowList(vpnIp iputil.VpnIp) *AllowList {
if al.insideAllowLists != nil {
inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
if inside != nil {
return inside.(*AllowList)
ok, inside := al.insideAllowLists.MostSpecificContainsIpV4(vpnIp)
if ok {
return inside
}
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion allow_list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
func TestAllowList_Allow(t *testing.T) {
assert.Equal(t, true, ((*AllowList)(nil)).Allow(net.ParseIP("1.1.1.1")))

tree := cidr.NewTree6()
tree := cidr.NewTree6[bool]()
tree.AddCIDR(cidr.Parse("0.0.0.0/0"), true)
tree.AddCIDR(cidr.Parse("10.0.0.0/8"), false)
tree.AddCIDR(cidr.Parse("10.42.42.42/32"), true)
Expand Down
4 changes: 2 additions & 2 deletions calculated_remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ func (c *calculatedRemote) Apply(ip iputil.VpnIp) *Ip4AndPort {
return &Ip4AndPort{Ip: uint32(masked), Port: c.port}
}

func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4, error) {
func NewCalculatedRemotesFromConfig(c *config.C, k string) (*cidr.Tree4[[]*calculatedRemote], error) {
value := c.Get(k)
if value == nil {
return nil, nil
}

calculatedRemotes := cidr.NewTree4()
calculatedRemotes := cidr.NewTree4[[]*calculatedRemote]()

rawMap, ok := value.(map[any]any)
if !ok {
Expand Down
62 changes: 34 additions & 28 deletions cidr/tree4.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,36 @@ import (
"github.com/slackhq/nebula/iputil"
)

type Node struct {
left *Node
right *Node
parent *Node
value interface{}
type Node[T any] struct {
left *Node[T]
right *Node[T]
parent *Node[T]
hasValue bool
value T
}

type entry struct {
type entry[T any] struct {
CIDR *net.IPNet
Value *interface{}
Value T
}

type Tree4 struct {
root *Node
list []entry
type Tree4[T any] struct {
root *Node[T]
list []entry[T]
}

const (
startbit = iputil.VpnIp(0x80000000)
)

func NewTree4() *Tree4 {
tree := new(Tree4)
tree.root = &Node{}
tree.list = []entry{}
func NewTree4[T any]() *Tree4[T] {
tree := new(Tree4[T])
tree.root = &Node[T]{}
tree.list = []entry[T]{}
return tree
}

func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
func (tree *Tree4[T]) AddCIDR(cidr *net.IPNet, val T) {
bit := startbit
node := tree.root
next := tree.root
Expand Down Expand Up @@ -68,14 +69,15 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {
}
}

tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
node.value = val
node.hasValue = true
return
}

// Build up the rest of the tree we don't already have
for bit&mask != 0 {
next = &Node{}
next = &Node[T]{}
next.parent = node

if ip&bit != 0 {
Expand All @@ -90,17 +92,18 @@ func (tree *Tree4) AddCIDR(cidr *net.IPNet, val interface{}) {

// Final node marks our cidr, set the value
node.value = val
tree.list = append(tree.list, entry{CIDR: cidr, Value: &val})
node.hasValue = true
tree.list = append(tree.list, entry[T]{CIDR: cidr, Value: val})
}

// Contains finds the first match, which may be the least specific
func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {
func (tree *Tree4[T]) Contains(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root

for node != nil {
if node.value != nil {
return node.value
if node.hasValue {
return true, node.value
}

if ip&bit != 0 {
Expand All @@ -113,17 +116,18 @@ func (tree *Tree4) Contains(ip iputil.VpnIp) (value interface{}) {

}

return value
return false, value
}

// MostSpecificContains finds the most specific match
func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
func (tree *Tree4[T]) MostSpecificContains(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root

for node != nil {
if node.value != nil {
if node.hasValue {
value = node.value
ok = true
}

if ip&bit != 0 {
Expand All @@ -135,11 +139,12 @@ func (tree *Tree4) MostSpecificContains(ip iputil.VpnIp) (value interface{}) {
bit >>= 1
}

return value
return ok, value
}

// Match finds the most specific match
func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {
// TODO this is exact match
func (tree *Tree4[T]) Match(ip iputil.VpnIp) (ok bool, value T) {
bit := startbit
node := tree.root
lastNode := node
Expand All @@ -157,11 +162,12 @@ func (tree *Tree4) Match(ip iputil.VpnIp) (value interface{}) {

if bit == 0 && lastNode != nil {
value = lastNode.value
ok = true
}
return value
return ok, value
}

// List will return all CIDRs and their current values. Do not modify the contents!
func (tree *Tree4) List() []entry {
func (tree *Tree4[T]) List() []entry[T] {
return tree.list
}
Loading

0 comments on commit 5181cb0

Please sign in to comment.