Skip to content

Commit

Permalink
use IPSet in acls instead of string slice
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby authored and juanfont committed May 3, 2023
1 parent 1a7ae11 commit 735b185
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 104 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
- Profiles are continously generated in our integration tests.
- Fix systemd service file location in `.deb` packages [#1391](https://github.com/juanfont/headscale/pull/1391)
- Improvements on Noise implementation [#1379](https://github.com/juanfont/headscale/pull/1379)
- Replace node filter logic, ensuring nodes with access can see eachother [#1381](https://github.com/juanfont/headscale/pull/1381)

## 0.22.1 (2023-04-20)

Expand Down
186 changes: 120 additions & 66 deletions acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ import (
"time"

"github.com/rs/zerolog/log"
"github.com/samber/lo"
"github.com/tailscale/hujson"
"go4.org/netipx"
"gopkg.in/yaml.v3"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
Expand Down Expand Up @@ -272,21 +272,41 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {

principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
for innerIndex, rawSrc := range sshACL.Sources {
expandedSrcs, err := h.aclPolicy.expandAlias(
machines,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)

return nil, err
}
for _, expandedSrc := range expandedSrcs {
if isWildcard(rawSrc) {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc,
Any: true,
})
} else if isGroup(rawSrc) {
users, err := h.aclPolicy.getUsersInGroup(rawSrc, h.cfg.OIDC.StripEmaildomain)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)

return nil, err
}

for _, user := range users {
principals = append(principals, &tailcfg.SSHPrincipal{
UserLogin: user,
})
}
} else {
expandedSrcs, err := h.aclPolicy.expandAlias(
machines,
rawSrc,
h.cfg.OIDC.StripEmaildomain,
)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Source %d", index, innerIndex)

return nil, err
}
for _, expandedSrc := range expandedSrcs.Prefixes() {
principals = append(principals, &tailcfg.SSHPrincipal{
NodeIP: expandedSrc.Addr().String(),
})
}
}
}

Expand All @@ -295,10 +315,9 @@ func (h *Headscale) generateSSHRules() ([]*tailcfg.SSHRule, error) {
userMap[user] = "="
}
rules = append(rules, &tailcfg.SSHRule{
RuleExpires: nil,
Principals: principals,
SSHUsers: userMap,
Action: &action,
Principals: principals,
SSHUsers: userMap,
Action: &action,
})
}

Expand Down Expand Up @@ -329,7 +348,18 @@ func (pol *ACLPolicy) getIPsFromSource(
machines []Machine,
stripEmaildomain bool,
) ([]string, error) {
return pol.expandAlias(machines, src, stripEmaildomain)
ipSet, err := pol.expandAlias(machines, src, stripEmaildomain)
if err != nil {
return []string{}, err
}

prefixes := []string{}

for _, prefix := range ipSet.Prefixes() {
prefixes = append(prefixes, prefix.String())
}

return prefixes, nil
}

// getNetPortRangeFromDestination returns a set of tailcfg.NetPortRange
Expand Down Expand Up @@ -397,11 +427,11 @@ func (pol *ACLPolicy) getNetPortRangeFromDestination(
}

dests := []tailcfg.NetPortRange{}
for _, d := range expanded {
for _, p := range *ports {
for _, dest := range expanded.Prefixes() {
for _, port := range *ports {
pr := tailcfg.NetPortRange{
IP: d,
Ports: p,
IP: dest.String(),
Ports: port,
}
dests = append(dests, pr)
}
Expand Down Expand Up @@ -472,28 +502,30 @@ func (pol *ACLPolicy) expandAlias(
machines Machines,
alias string,
stripEmailDomain bool,
) ([]string, error) {
if alias == "*" {
return []string{"*"}, nil
) (*netipx.IPSet, error) {
if isWildcard(alias) {
return parseIPSet("*", nil)
}

build := netipx.IPSetBuilder{}

log.Debug().
Str("alias", alias).
Msg("Expanding")

// if alias is a group
if strings.HasPrefix(alias, "group:") {
if isGroup(alias) {
return pol.getIPsFromGroup(alias, machines, stripEmailDomain)
}

// if alias is a tag
if strings.HasPrefix(alias, "tag:") {
if isTag(alias) {
return pol.getIPsFromTag(alias, machines, stripEmailDomain)
}

// if alias is a user
if ips := pol.getIPsForUser(alias, machines, stripEmailDomain); len(ips) > 0 {
return ips, nil
if ips, err := pol.getIPsForUser(alias, machines, stripEmailDomain); ips != nil {
return ips, err
}

// if alias is an host
Expand All @@ -516,7 +548,7 @@ func (pol *ACLPolicy) expandAlias(

log.Warn().Msgf("No IPs found with the alias %v", alias)

return []string{}, nil
return build.IPSet()
}

// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones
Expand Down Expand Up @@ -561,7 +593,7 @@ func excludeCorrectlyTaggedNodes(
}

func expandPorts(portsStr string, needsWildcard bool) (*[]tailcfg.PortRange, error) {
if portsStr == "*" {
if isWildcard(portsStr) {
return &[]tailcfg.PortRange{
{First: portRangeBegin, Last: portRangeEnd},
}, nil
Expand Down Expand Up @@ -636,7 +668,7 @@ func getTagOwners(
)
}
for _, owner := range ows {
if strings.HasPrefix(owner, "group:") {
if isGroup(owner) {
gs, err := pol.getUsersInGroup(owner, stripEmailDomain)
if err != nil {
return []string{}, err
Expand Down Expand Up @@ -667,7 +699,7 @@ func (pol *ACLPolicy) getUsersInGroup(
)
}
for _, group := range aclGroups {
if strings.HasPrefix(group, "group:") {
if isGroup(group) {
return []string{}, fmt.Errorf(
"%w. A group cannot be composed of groups. https://tailscale.com/kb/1018/acls/#groups",
errInvalidGroup,
Expand All @@ -691,52 +723,53 @@ func (pol *ACLPolicy) getIPsFromGroup(
group string,
machines Machines,
stripEmailDomain bool,
) ([]string, error) {
ips := []string{}
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}

users, err := pol.getUsersInGroup(group, stripEmailDomain)
if err != nil {
return ips, err
return &netipx.IPSet{}, err
}
for _, n := range users {
nodes := filterMachinesByUser(machines, n)
for _, node := range nodes {
ips = append(ips, node.IPAddresses.ToStringSlice()...)
for _, user := range users {
filteredMachines := filterMachinesByUser(machines, user)
for _, machine := range filteredMachines {
machine.IPAddresses.AppendToIPSet(&build)
}
}

return ips, nil
return build.IPSet()
}

func (pol *ACLPolicy) getIPsFromTag(
alias string,
machines Machines,
stripEmailDomain bool,
) ([]string, error) {
ips := []string{}
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}

// check for forced tags
for _, machine := range machines {
if contains(machine.ForcedTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}
}

// find tag owners
owners, err := getTagOwners(pol, alias, stripEmailDomain)
if err != nil {
if errors.Is(err, errInvalidTag) {
if len(ips) == 0 {
return ips, fmt.Errorf(
ipSet, _ := build.IPSet()
if len(ipSet.Prefixes()) == 0 {
return ipSet, fmt.Errorf(
"%w. %v isn't owned by a TagOwner and no forced tags are defined",
errInvalidTag,
alias,
)
}

return ips, nil
return build.IPSet()
} else {
return ips, err
return nil, err
}
}

Expand All @@ -746,64 +779,85 @@ func (pol *ACLPolicy) getIPsFromTag(
for _, machine := range machines {
hi := machine.GetHostInfo()
if contains(hi.RequestTags, alias) {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}
}
}

return ips, nil
return build.IPSet()
}

func (pol *ACLPolicy) getIPsForUser(
user string,
machines Machines,
stripEmailDomain bool,
) []string {
ips := []string{}
) (*netipx.IPSet, error) {
build := netipx.IPSetBuilder{}

nodes := filterMachinesByUser(machines, user)
nodes = excludeCorrectlyTaggedNodes(pol, nodes, user, stripEmailDomain)
filteredMachines := filterMachinesByUser(machines, user)
filteredMachines = excludeCorrectlyTaggedNodes(pol, filteredMachines, user, stripEmailDomain)

for _, n := range nodes {
ips = append(ips, n.IPAddresses.ToStringSlice()...)
// shortcurcuit if we have no machines to get ips from.
if len(filteredMachines) == 0 {
return nil, nil //nolint
}

return ips
for _, machine := range filteredMachines {
machine.IPAddresses.AppendToIPSet(&build)
}

return build.IPSet()
}

func (pol *ACLPolicy) getIPsFromSingleIP(
ip netip.Addr,
machines Machines,
) ([]string, error) {
) (*netipx.IPSet, error) {
log.Trace().Str("ip", ip.String()).Msg("expandAlias got ip")

ips := []string{ip.String()}
matches := machines.FilterByIP(ip)

build := netipx.IPSetBuilder{}
build.Add(ip)

for _, machine := range matches {
ips = append(ips, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}

return lo.Uniq(ips), nil
return build.IPSet()
}

func (pol *ACLPolicy) getIPsFromIPPrefix(
prefix netip.Prefix,
machines Machines,
) ([]string, error) {
) (*netipx.IPSet, error) {
log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix")
val := []string{prefix.String()}
build := netipx.IPSetBuilder{}
build.AddPrefix(prefix)

// This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6
// addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers.
for _, machine := range machines {
for _, ip := range machine.IPAddresses {
// log.Trace().
// Msgf("checking if machine ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String())
if prefix.Contains(ip) {
val = append(val, machine.IPAddresses.ToStringSlice()...)
machine.IPAddresses.AppendToIPSet(&build)
}
}
}

return lo.Uniq(val), nil
return build.IPSet()
}

func isWildcard(str string) bool {
return str == "*"
}

func isGroup(str string) bool {
return strings.HasPrefix(str, "group:")
}

func isTag(str string) bool {
return strings.HasPrefix(str, "tag:")
}
Loading

0 comments on commit 735b185

Please sign in to comment.