Skip to content

Commit

Permalink
optimize generateACLPeerCacheMap (juanfont#1377)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkrivanec authored Apr 26, 2023
1 parent 6215eb6 commit d011373
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 34 deletions.
15 changes: 6 additions & 9 deletions acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,23 +163,20 @@ func (h *Headscale) UpdateACLRules() error {
// generateACLPeerCacheMap takes a list of Tailscale filter rules and generates a map
// of which Sources ("*" and IPs) can access destinations. This is to speed up the
// process of generating MapResponses when deciding which Peers to inform nodes about.
func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string]map[string]struct{} {
aclCachePeerMap := make(map[string]map[string]struct{})
func generateACLPeerCacheMap(rules []tailcfg.FilterRule) map[string][]string {
aclCachePeerMap := make(map[string][]string)
for _, rule := range rules {
for _, srcIP := range rule.SrcIPs {
for _, ip := range expandACLPeerAddr(srcIP) {
if data, ok := aclCachePeerMap[ip]; ok {
for _, dstPort := range rule.DstPorts {
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
data[dstIP] = struct{}{}
}
data = append(data, dstPort.IP)
}
aclCachePeerMap[ip] = data
} else {
dstPortsMap := make(map[string]struct{}, len(rule.DstPorts))
dstPortsMap := make([]string, 0)
for _, dstPort := range rule.DstPorts {
for _, dstIP := range expandACLPeerAddr(dstPort.IP) {
dstPortsMap[dstIP] = struct{}{}
}
dstPortsMap = append(dstPortsMap, dstPort.IP)
}
aclCachePeerMap[ip] = dstPortsMap
}
Expand Down
2 changes: 1 addition & 1 deletion app.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ type Headscale struct {
aclPolicy *ACLPolicy
aclRules []tailcfg.FilterRule
aclPeerCacheMapRW sync.RWMutex
aclPeerCacheMap map[string]map[string]struct{}
aclPeerCacheMap map[string][]string
sshPolicy *tailcfg.SSHPolicy

lastStateChange *xsync.MapOf[string, time.Time]
Expand Down
73 changes: 49 additions & 24 deletions machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"database/sql/driver"
"errors"
"fmt"
"net"
"net/netip"
"sort"
"strconv"
Expand Down Expand Up @@ -172,7 +173,7 @@ func filterMachinesByACL(
machine *Machine,
machines Machines,
lock *sync.RWMutex,
aclPeerCacheMap map[string]map[string]struct{},
aclPeerCacheMap map[string][]string,
) Machines {
log.Trace().
Caller().
Expand All @@ -197,46 +198,63 @@ func filterMachinesByACL(

if dstMap, ok := aclPeerCacheMap["*"]; ok {
// match source and all destination
if _, dstOk := dstMap["*"]; dstOk {
peers[peer.ID] = peer

continue
for _, dst := range dstMap {
if dst == "*" {
peers[peer.ID] = peer

continue
}
}

// match source and all destination
for _, peerIP := range peerIPs {
if _, dstOk := dstMap[peerIP]; dstOk {
peers[peer.ID] = peer
for _, dst := range dstMap {
_, cdr, _ := net.ParseCIDR(dst)
ip := net.ParseIP(peerIP)
if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
peers[peer.ID] = peer

continue
continue
}
}
}

// match all sources and source
for _, machineIP := range machineIPs {
if _, dstOk := dstMap[machineIP]; dstOk {
peers[peer.ID] = peer
for _, dst := range dstMap {
_, cdr, _ := net.ParseCIDR(dst)
ip := net.ParseIP(machineIP)
if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
peers[peer.ID] = peer

continue
continue
}
}
}
}

for _, machineIP := range machineIPs {
if dstMap, ok := aclPeerCacheMap[machineIP]; ok {
// match source and all destination
if _, dstOk := dstMap["*"]; dstOk {
peers[peer.ID] = peer
for _, dst := range dstMap {
if dst == "*" {
peers[peer.ID] = peer

continue
continue
}
}

// match source and destination
for _, peerIP := range peerIPs {
if _, dstOk := dstMap[peerIP]; dstOk {
peers[peer.ID] = peer

continue
for _, dst := range dstMap {
_, cdr, _ := net.ParseCIDR(dst)
ip := net.ParseIP(peerIP)
if dst == peerIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
peers[peer.ID] = peer

continue
}
}
}
}
Expand All @@ -245,17 +263,24 @@ func filterMachinesByACL(
for _, peerIP := range peerIPs {
if dstMap, ok := aclPeerCacheMap[peerIP]; ok {
// match source and all destination
if _, dstOk := dstMap["*"]; dstOk {
peers[peer.ID] = peer
for _, dst := range dstMap {
if dst == "*" {
peers[peer.ID] = peer

continue
continue
}
}

// match return path
for _, machineIP := range machineIPs {
if _, dstOk := dstMap[machineIP]; dstOk {
peers[peer.ID] = peer

continue
for _, dst := range dstMap {
_, cdr, _ := net.ParseCIDR(dst)
ip := net.ParseIP(machineIP)
if dst == machineIP || (cdr != nil && ip != nil && cdr.Contains(ip)) {
peers[peer.ID] = peer

continue
}
}
}
}
Expand Down

0 comments on commit d011373

Please sign in to comment.