Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add trace tree for dns lookup count and change tracing #18

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions cmd/spf-flatten/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ func main() {

/// Flatten SPF record for input domain
r := spf.NewRootSPF(inputs.rootDomain, spf.NetLookup{}, inputs.keep)
if err = r.FlattenSPF(r.RootDomain, inputs.initialSPF); err != nil {
if err = r.FlattenSPF(r.RootDomain, inputs.initialSPF, r.TraceTree.Root()); err != nil {
slog.Error("Could not flatten SPF record for initial domain", "error", err)
os.Exit(1)
}
flatSPF := r.WriteFlatSPF()

// Output flattened SPF record
slog.Info("Successfully flattened SPF record for initial domain", "flattened_record", flatSPF)
slog.Info("Successfully flattened SPF record for initial domain", "flattened_record", flatSPF, "num_dns_lookups", r.LookupCount)

if inputs.warn {
// Compare flattened SPF to SPF record currently set for r.RootDomain
Expand All @@ -96,7 +96,15 @@ func main() {
return
}
slog.Warn("Flattened SPF record differs from initial SPF record", "removed_from_initial", inCurrent, "added_in_flattened", inFlat)
traceMap := spf.TraceChanges(inCurrent+" "+inFlat, r.TraceTree)
for k, v := range traceMap {
slog.Warn("Trace source of changes", "change_source", k, "changes", v)
}
}
if r.LookupCount > 10 {
slog.Error("Final SPF record requires more than 10 DNS lookups")
}

if inputs.dryrun {
slog.Info("Dryrun complete")
return
Expand Down
74 changes: 47 additions & 27 deletions internal/spf/flatten.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,18 @@ func writeIPMech(ip net.IP, prefix string) string {
type RootSPF struct {
RootDomain string
AllMechanism string
MapKeep map[string]bool
Keeps string
MapIPs map[string]bool
MapNonflat map[string]bool
LookupIF Lookup
LookupCount int
TraceTree Tree
}

func NewRootSPF(rootDomain string, lookupIF Lookup, keep string) RootSPF {
mapKeep := make(map[string]bool)
for _, mechanism := range strings.Fields(keep) {
mapKeep[mechanism] = true
}
return RootSPF{RootDomain: rootDomain, LookupIF: lookupIF, MapKeep: mapKeep,
MapIPs: map[string]bool{}, MapNonflat: map[string]bool{}}
return RootSPF{RootDomain: rootDomain, Keeps: keep, MapIPs: map[string]bool{}, MapNonflat: map[string]bool{},
LookupIF: lookupIF, LookupCount: 0, TraceTree: Tree{root: &node{name: rootDomain}}}

}

var allInRecordRegex = regexp.MustCompile(`^.* (\+|-|~|\?)?all$`)
Expand All @@ -59,7 +58,7 @@ var modifierRegex = regexp.MustCompile(`^(\+|-|~|\?)$`)
// Lookup or check SPF record for domain, then parse each mechanism.
// This runs recursively until every mechanism is added either to
// r.AllMechanism, r.MapIPs, or r.MapNonflat (or ignored)
func (r *RootSPF) FlattenSPF(domain, spfRecord string) error {
func (r *RootSPF) FlattenSPF(domain, spfRecord string, parent Node) error {
slog.Debug("--- Flattening domain ---", "domain", domain)
if spfRecord == "" {
record, err := GetDomainSPFRecord(domain, r.LookupIF)
Expand All @@ -75,11 +74,6 @@ func (r *RootSPF) FlattenSPF(domain, spfRecord string) error {
}
containsAll := allInRecordRegex.MatchString(spfRecord)
for _, mechanism := range strings.Fields(spfRecord)[1:] {
// If mechanism is in string of "keep" mechanisms, add to MapNonflat and don't parse
if _, ok := r.MapKeep[mechanism]; ok {
r.MapNonflat[mechanism] = true
continue
}
// If not `all`, skip mechanism if fail modifier (- or ~) and ignore modifier otherwise
if modifierRegex.MatchString(mechanism[:1]) && !allRegex.MatchString(mechanism) {
if mechanism[:1] == "-" || mechanism[:1] == "~" {
Expand All @@ -93,7 +87,7 @@ func (r *RootSPF) FlattenSPF(domain, spfRecord string) error {
continue
}
// Parse mechanism
err := r.ParseMechanism(strings.TrimSpace(mechanism), domain)
err := r.ParseMechanism(strings.TrimSpace(mechanism), domain, parent)
if err != nil {
return fmt.Errorf("could not flatten SPF record for %s: %s\n", domain, err)
}
Expand All @@ -117,8 +111,10 @@ var nonflatRegex = regexp.MustCompile(`^(ptr:|exists:|exp=).*$`)
var includeOrRedirectRegex = regexp.MustCompile(`^(include:|redirect=).*$`)

// Parse the given mechanism and dispatch it accordingly
func (r *RootSPF) ParseMechanism(mechanism, domain string) error {
func (r *RootSPF) ParseMechanism(mechanism, domain string, parent Node) error {
lastSlashIndex := strings.LastIndex(mechanism, "/")
childNode := &node{name: mechanism, parent: parent}
parent.AddChild(childNode)
switch {
// Copy `all` mechanism if set by ROOT_DOMAIN
case allRegex.MatchString(mechanism):
Expand All @@ -132,21 +128,21 @@ func (r *RootSPF) ParseMechanism(mechanism, domain string) error {
r.MapIPs[mechanism] = true
// Convert A/AAAA and MX records, then add to r.MapIPs
case mechanism == "a": // a
return r.ConvertDomainToIP(domain, "")
return r.ConvertDomainToIP(domain, "", childNode)
case aPrefixRegex.MatchString(mechanism): // a/<prefix-length>
return r.ConvertDomainToIP(domain, mechanism[1:])
return r.ConvertDomainToIP(domain, mechanism[1:], childNode)
case aDomainPrefixRegex.MatchString(mechanism): // a:<domain>/<prefix-length>
return r.ConvertDomainToIP(mechanism[2:lastSlashIndex], mechanism[lastSlashIndex:])
return r.ConvertDomainToIP(mechanism[2:lastSlashIndex], mechanism[lastSlashIndex:], childNode)
case aDomainRegex.MatchString(mechanism): // a:<domain>
return r.ConvertDomainToIP(mechanism[2:], "")
return r.ConvertDomainToIP(mechanism[2:], "", childNode)
case mechanism == "mx": // mx
return r.ConvertMxToIP(domain, "")
return r.ConvertMxToIP(domain, "", childNode)
case mxPrefixRegex.MatchString(mechanism): // mx/<prefix-length>
return r.ConvertMxToIP(domain, mechanism[2:])
return r.ConvertMxToIP(domain, mechanism[2:], childNode)
case mxDomainPrefixRegex.MatchString(mechanism): // mx:<domain>/<prefix-length>
return r.ConvertMxToIP(mechanism[3:lastSlashIndex], mechanism[lastSlashIndex:])
return r.ConvertMxToIP(mechanism[3:lastSlashIndex], mechanism[lastSlashIndex:], childNode)
case mxDomainRegex.MatchString(mechanism): // mx:<domain>
return r.ConvertMxToIP(mechanism[3:], "")
return r.ConvertMxToIP(mechanism[3:], "", childNode)
// Add ptr, exists, and exp mechanisms to r.MapNonflat
case mechanism == "ptr":
slog.Debug("Adding nonflat mechanism", "mechanism", mechanism+":"+domain)
Expand All @@ -156,7 +152,7 @@ func (r *RootSPF) ParseMechanism(mechanism, domain string) error {
r.MapNonflat[mechanism] = true
// Recursive call to FlattenSPF on `include` and `redirect` mechanism
case includeOrRedirectRegex.MatchString(mechanism):
return r.FlattenSPF(mechanism[strings.IndexAny(mechanism, ":=")+1:], "")
return r.FlattenSPF(mechanism[strings.IndexAny(mechanism, ":=")+1:], "", childNode)
// Return error if no match
default:
return fmt.Errorf("received unexpected SPF mechanism or syntax: '%s'", mechanism)
Expand All @@ -165,37 +161,61 @@ func (r *RootSPF) ParseMechanism(mechanism, domain string) error {
}

// Convert A/AAAA records to IPs and add them to r.MapIPs
func (r *RootSPF) ConvertDomainToIP(domain, prefixLength string) error {
func (r *RootSPF) ConvertDomainToIP(domain, prefixLength string, parent Node) error {
slog.Debug("Looking up IP records for domain", "domain", domain)
ips, err := r.LookupIF.LookupIP(domain)
if err != nil {
return fmt.Errorf("could not lookup IPs for %s: %s\n", domain, err)
}
for _, ip := range ips {
childNode := &node{name: writeIPMech(ip, prefixLength), parent: parent}
parent.AddChild(childNode)
slog.Debug("Adding IP mechanism", "mechanism", writeIPMech(ip, prefixLength))
r.MapIPs[writeIPMech(ip, prefixLength)] = true
}
return nil
}

// Convert MX records to domains then to IPs and add them to r.MapIPs
func (r *RootSPF) ConvertMxToIP(domain, prefixLength string) error {
func (r *RootSPF) ConvertMxToIP(domain, prefixLength string, parent Node) error {
slog.Debug("Looking up MX records for domain", "domain", domain)
mxs, err := r.LookupIF.LookupMX(domain)
if err != nil {
return fmt.Errorf("could not lookup MX records for %s: %s\n", domain, err)
}
for _, mx := range mxs {
childNode := &node{name: mx.Host, parent: parent}
parent.AddChild(childNode)
slog.Debug("Found MX record for domain", "mx_record", mx.Host)
if err := r.ConvertDomainToIP(mx.Host, prefixLength); err != nil {
if err := r.ConvertDomainToIP(mx.Host, prefixLength, childNode); err != nil {
return fmt.Errorf("could not lookup IPs for MX record `%s`: %s\n", mx.Host, err)
}
}
return nil
}

// Remove all mechanisms flattened by "keeps", add "keeps" to r.MapNonflat
// and count DNS lookups required by final SPF record
func (r *RootSPF) UnflattenKeeps() {
r.LookupCount += len(r.MapNonflat)
for _, keep := range strings.Fields(r.Keeps) {
if keepNode := r.TraceTree.FindNode(keep); keepNode != nil {
keepSubtree := r.TraceTree.GetSubtree(keepNode, []string{})
for _, node := range keepSubtree {
if !strings.HasPrefix(node, "ip") && !strings.HasSuffix(node, "all") {
r.LookupCount += 1
}
delete(r.MapIPs, node)
delete(r.MapNonflat, node)
}
r.MapNonflat[keep] = true
}
}
}

// Flatten and write new SPF record for root domain by compiling r.AllMechanism, r.MapIPs, and r.MapNonflat
func (r *RootSPF) WriteFlatSPF() string {
r.UnflattenKeeps()
flatSPF := "v=spf1"
for ip := range r.MapIPs {
flatSPF += " " + ip
Expand Down
31 changes: 16 additions & 15 deletions internal/spf/flatten_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ func (r *RootSPF) compareExpected(err error, expAll string, expIPs, expNF []stri
func TestParseMechanismAll(t *testing.T) {
r := NewRootSPF("myrootdomain", mockLookup{}, "")
// Test `all`` mechanism set if domain is root
err := r.ParseMechanism("~all", "myrootdomain")
err := r.ParseMechanism("~all", "myrootdomain", r.TraceTree.Root())
if err = r.compareExpected(err, " ~all", []string{}, []string{}); err != nil {
t.Fatal(err)
}
// Test `all`` mechanism is ignored if domain is NOT root
err = r.ParseMechanism("-all", "NOTmyrootdomain")
err = r.ParseMechanism("-all", "NOTmyrootdomain", r.TraceTree.Root())
if err = r.compareExpected(err, " ~all", []string{}, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -79,7 +79,7 @@ func TestParseMechanismIP(t *testing.T) {
// Test ip mechanisms of the form `ip4:<ipaddr>`, `ip4:<ipaddr>/<prefix-length>, `ip6:<ipaddr>`, and `ip6:<ipaddr>/<prefix-length>`
ipMechs := []string{"ip4:abcd", "ip4:8.8.8.8", "ip6:efgh/36", "ip6:2001:0db8:85a3:0000:0000:8a2e:0370:7334", "ip6:11:22::33/128"}
for _, mech := range ipMechs {
err := r.ParseMechanism(mech, "")
err := r.ParseMechanism(mech, "", r.TraceTree.Root())
if err = r.compareExpected(err, "", []string{mech}, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -101,7 +101,7 @@ func TestParseMechanismA(t *testing.T) {
for _, ip := range ipLookup[testCase[2]] {
expIPs = append(expIPs, writeIPMech(ip, testCase[1]))
}
err := r.ParseMechanism(testCase[0], testCase[3])
err := r.ParseMechanism(testCase[0], testCase[3], r.TraceTree.Root())
if err = r.compareExpected(err, "", expIPs, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -125,7 +125,7 @@ func TestParseMechanismMX(t *testing.T) {
expIPs = append(expIPs, writeIPMech(ip, testCase[1]))
}
}
err := r.ParseMechanism(testCase[0], testCase[3])
err := r.ParseMechanism(testCase[0], testCase[3], r.TraceTree.Root())
if err = r.compareExpected(err, "", expIPs, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -135,14 +135,14 @@ func TestParseMechanismMX(t *testing.T) {
func TestParseMechanismNonFlat(t *testing.T) {
r := NewRootSPF("", mockLookup{}, "")
// Test ptr mechanism of the form `ptr`
err := r.ParseMechanism("ptr", "domain")
err := r.ParseMechanism("ptr", "domain", r.TraceTree.Root())
if err = r.compareExpected(err, "", []string{}, []string{"ptr:domain"}); err != nil {
t.Fatal(err)
}
// Test nonflat mechanisms of the form `ptr:<domain>`, `<exists:<domain>`, and `exp=<domain>`
nfMechs := []string{"ptr:example.com", "exists:yourdomain", "exp=explain.example.com"}
for _, nfMech := range nfMechs {
err := r.ParseMechanism(nfMech, "")
err := r.ParseMechanism(nfMech, "", r.TraceTree.Root())
if err = r.compareExpected(err, "", []string{}, []string{nfMech}); err != nil {
t.Fatal(err)
}
Expand All @@ -157,7 +157,7 @@ func TestParseMechanismInclude(t *testing.T) {
for _, ip := range ipLookup[includeDomain] {
expIPs = append(expIPs, writeIPMech(ip, ""))
}
err := r.ParseMechanism("include:"+includeDomain, "notmydomain")
err := r.ParseMechanism("include:"+includeDomain, "notmydomain", r.TraceTree.Root())
if err = r.compareExpected(err, "", expIPs, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -167,7 +167,7 @@ func TestParseMechanismRedirect(t *testing.T) {
r := NewRootSPF("", mockLookup{}, "")
// Test mechanism of the form `redirect=<domain>`
redirectDomain := "test.com" // SPF record is just ip4:10.10.10.10
err := r.ParseMechanism("redirect="+redirectDomain, "notmydomain")
err := r.ParseMechanism("redirect="+redirectDomain, "notmydomain", r.TraceTree.Root())
if err = r.compareExpected(err, "", []string{"ip4:10.10.10.10"}, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -178,7 +178,7 @@ func TestParseMechanismFails(t *testing.T) {
// Test that parseMechanism fails on unexpected mechanism or syntax error
noMatchRegex := regexp.MustCompile(`^received unexpected SPF mechanism or syntax.*$`)
for _, wrongMech := range []string{"redirect:domain", "include=anotherdomain", "ip:0.0.0.0", "1.1.1.1", "", "ip6", "exp:explanation", "notMechanism:hello"} {
err := r.ParseMechanism(wrongMech, "")
err := r.ParseMechanism(wrongMech, "", r.TraceTree.Root())
if !noMatchRegex.MatchString(err.Error()) {
t.Fatalf("Expected `received unexpected SPF mechanism or syntax` error, got `%s` instead", err)
}
Expand All @@ -200,7 +200,7 @@ func TestFlattenSPF(t *testing.T) {
expIPs = append(expIPs, writeIPMech(ip, ""))
}
}
err := r.FlattenSPF(domain, spf)
err := r.FlattenSPF(domain, spf, r.TraceTree.Root())
if err = r.compareExpected(err, " -all", expIPs, expNFs); err != nil {
t.Fatal(err)
}
Expand All @@ -218,7 +218,8 @@ func TestNoFlattenKeeps(t *testing.T) {
expIPs = append(expIPs, writeIPMech(ip, ""))
}
}
err := r.FlattenSPF(domain, spf)
err := r.FlattenSPF(domain, spf, r.TraceTree.Root())
r.UnflattenKeeps()
if err = r.compareExpected(err, " -all", expIPs, expNFs); err != nil {
t.Fatal(err)
}
Expand All @@ -241,7 +242,7 @@ func TestFlattenRedirects(t *testing.T) {
for _, ip := range ipLookup["mydomain"] {
expIPs = append(expIPs, writeIPMech(ip, ""))
}
err := r.FlattenSPF(domain, spf)
err := r.FlattenSPF(domain, spf, r.TraceTree.Root())
if err = r.compareExpected(err, "", expIPs, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -250,7 +251,7 @@ func TestFlattenRedirects(t *testing.T) {
}
// Test that redirects are ignored if SPF record includes `all` mechanism
r.MapIPs = map[string]bool{}
err = r.FlattenSPF(domain, spf+" ~all")
err = r.FlattenSPF(domain, spf+" ~all", r.TraceTree.Root())
if err = r.compareExpected(err, "", []string{"ip4:9.9.9.9"}, []string{}); err != nil {
t.Fatal(err)
}
Expand All @@ -270,7 +271,7 @@ func TestFlattenModifiers(t *testing.T) {
expIPs = append(expIPs, writeIPMech(ip, ""))
}
}
err := r.FlattenSPF(domain, spf)
err := r.FlattenSPF(domain, spf, r.TraceTree.Root())
if err = r.compareExpected(err, " -all", expIPs, []string{}); err != nil {
t.Fatal(err)
}
Expand Down
20 changes: 20 additions & 0 deletions internal/spf/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ func CheckSPFRecord(domain, spfRecord string, lookupIF Lookup) error {
return fmt.Errorf("SPF record for %s did not match expected format. Got '%s'", domain, spfRecord)
}

// For each mechanism in "mechanisms", lookup trace of changes/flattening
// and store in map with traces as keys and group of mechanisms with shared trace
// as values, then return map. E.g. map["include:example.com => a"] = "ip4:0.0.0.0 ip4:1.2.3.4"
func TraceChanges(mechanisms string, traceTree Tree) map[string]string {
mapTraces := map[string]string{}
for _, mech := range strings.Fields(mechanisms) {
node := traceTree.FindNode(mech)
if node == nil {
fmt.Printf("couldn't find node %s\n", mech)
}
trace := strings.Join(traceTree.GetAncestors(node), " => ")
if _, ok := mapTraces[trace]; ok {
mapTraces[trace] += " " + mech
} else {
mapTraces[trace] = mech
}
}
return mapTraces
}

// Compare intial and flattened SPF records by checking that they both
// have the same entries regardless of order. Return any different entries.
func CompareRecords(startSPF, endSPF string) (bool, string, string) {
Expand Down
Loading
Loading