diff --git a/cmd/spf-flatten/main.go b/cmd/spf-flatten/main.go index ef443f1..fea580a 100644 --- a/cmd/spf-flatten/main.go +++ b/cmd/spf-flatten/main.go @@ -74,14 +74,14 @@ func main() { /// Flatten SPF record for input domain r := spf.NewRootSPF(inputs.rootDomain, spf.NetLookup{}, inputs.keep) - if err = r.FlattenSPF(r.RootDomain, inputs.initialSPF); err != nil { + if err = r.FlattenSPF(r.RootDomain, inputs.initialSPF, r.TraceTree.Root()); err != nil { slog.Error("Could not flatten SPF record for initial domain", "error", err) os.Exit(1) } flatSPF := r.WriteFlatSPF() // Output flattened SPF record - slog.Info("Successfully flattened SPF record for initial domain", "flattened_record", flatSPF) + slog.Info("Successfully flattened SPF record for initial domain", "flattened_record", flatSPF, "num_dns_lookups", r.LookupCount) if inputs.warn { // Compare flattened SPF to SPF record currently set for r.RootDomain @@ -96,7 +96,15 @@ func main() { return } slog.Warn("Flattened SPF record differs from initial SPF record", "removed_from_initial", inCurrent, "added_in_flattened", inFlat) + traceMap := spf.TraceChanges(inCurrent+" "+inFlat, r.TraceTree) + for k, v := range traceMap { + slog.Warn("Trace source of changes", "change_source", k, "changes", v) + } + } + if r.LookupCount > 10 { + slog.Error("Final SPF record requires more than 10 DNS lookups") } + if inputs.dryrun { slog.Info("Dryrun complete") return diff --git a/internal/spf/flatten.go b/internal/spf/flatten.go index c06a7a3..201df23 100644 --- a/internal/spf/flatten.go +++ b/internal/spf/flatten.go @@ -38,19 +38,18 @@ func writeIPMech(ip net.IP, prefix string) string { type RootSPF struct { RootDomain string AllMechanism string - MapKeep map[string]bool + Keeps string MapIPs map[string]bool MapNonflat map[string]bool LookupIF Lookup + LookupCount int + TraceTree Tree } func NewRootSPF(rootDomain string, lookupIF Lookup, keep string) RootSPF { - mapKeep := make(map[string]bool) - for _, mechanism := range 
strings.Fields(keep) { - mapKeep[mechanism] = true - } - return RootSPF{RootDomain: rootDomain, LookupIF: lookupIF, MapKeep: mapKeep, - MapIPs: map[string]bool{}, MapNonflat: map[string]bool{}} + return RootSPF{RootDomain: rootDomain, Keeps: keep, MapIPs: map[string]bool{}, MapNonflat: map[string]bool{}, + LookupIF: lookupIF, LookupCount: 0, TraceTree: Tree{root: &node{name: rootDomain}}} + } var allInRecordRegex = regexp.MustCompile(`^.* (\+|-|~|\?)?all$`) @@ -59,7 +58,7 @@ var modifierRegex = regexp.MustCompile(`^(\+|-|~|\?)$`) // Lookup or check SPF record for domain, then parse each mechanism. // This runs recursively until every mechanism is added either to // r.AllMechanism, r.MapIPs, or r.MapNonflat (or ignored) -func (r *RootSPF) FlattenSPF(domain, spfRecord string) error { +func (r *RootSPF) FlattenSPF(domain, spfRecord string, parent Node) error { slog.Debug("--- Flattening domain ---", "domain", domain) if spfRecord == "" { record, err := GetDomainSPFRecord(domain, r.LookupIF) @@ -75,11 +74,6 @@ func (r *RootSPF) FlattenSPF(domain, spfRecord string) error { } containsAll := allInRecordRegex.MatchString(spfRecord) for _, mechanism := range strings.Fields(spfRecord)[1:] { - // If mechanism is in string of "keep" mechanisms, add to MapNonflat and don't parse - if _, ok := r.MapKeep[mechanism]; ok { - r.MapNonflat[mechanism] = true - continue - } // If not `all`, skip mechanism if fail modifier (- or ~) and ignore modifier otherwise if modifierRegex.MatchString(mechanism[:1]) && !allRegex.MatchString(mechanism) { if mechanism[:1] == "-" || mechanism[:1] == "~" { @@ -93,7 +87,7 @@ func (r *RootSPF) FlattenSPF(domain, spfRecord string) error { continue } // Parse mechanism - err := r.ParseMechanism(strings.TrimSpace(mechanism), domain) + err := r.ParseMechanism(strings.TrimSpace(mechanism), domain, parent) if err != nil { return fmt.Errorf("could not flatten SPF record for %s: %s\n", domain, err) } @@ -117,8 +111,10 @@ var nonflatRegex = 
regexp.MustCompile(`^(ptr:|exists:|exp=).*$`) var includeOrRedirectRegex = regexp.MustCompile(`^(include:|redirect=).*$`) // Parse the given mechanism and dispatch it accordingly -func (r *RootSPF) ParseMechanism(mechanism, domain string) error { +func (r *RootSPF) ParseMechanism(mechanism, domain string, parent Node) error { lastSlashIndex := strings.LastIndex(mechanism, "/") + childNode := &node{name: mechanism, parent: parent} + parent.AddChild(childNode) switch { // Copy `all` mechanism if set by ROOT_DOMAIN case allRegex.MatchString(mechanism): @@ -132,21 +128,21 @@ func (r *RootSPF) ParseMechanism(mechanism, domain string) error { r.MapIPs[mechanism] = true // Convert A/AAAA and MX records, then add to r.MapIPs case mechanism == "a": // a - return r.ConvertDomainToIP(domain, "") + return r.ConvertDomainToIP(domain, "", childNode) case aPrefixRegex.MatchString(mechanism): // a/ - return r.ConvertDomainToIP(domain, mechanism[1:]) + return r.ConvertDomainToIP(domain, mechanism[1:], childNode) case aDomainPrefixRegex.MatchString(mechanism): // a:/ - return r.ConvertDomainToIP(mechanism[2:lastSlashIndex], mechanism[lastSlashIndex:]) + return r.ConvertDomainToIP(mechanism[2:lastSlashIndex], mechanism[lastSlashIndex:], childNode) case aDomainRegex.MatchString(mechanism): // a: - return r.ConvertDomainToIP(mechanism[2:], "") + return r.ConvertDomainToIP(mechanism[2:], "", childNode) case mechanism == "mx": // mx - return r.ConvertMxToIP(domain, "") + return r.ConvertMxToIP(domain, "", childNode) case mxPrefixRegex.MatchString(mechanism): // mx/ - return r.ConvertMxToIP(domain, mechanism[2:]) + return r.ConvertMxToIP(domain, mechanism[2:], childNode) case mxDomainPrefixRegex.MatchString(mechanism): // mx:/ - return r.ConvertMxToIP(mechanism[3:lastSlashIndex], mechanism[lastSlashIndex:]) + return r.ConvertMxToIP(mechanism[3:lastSlashIndex], mechanism[lastSlashIndex:], childNode) case mxDomainRegex.MatchString(mechanism): // mx: - return r.ConvertMxToIP(mechanism[3:], 
"") + return r.ConvertMxToIP(mechanism[3:], "", childNode) // Add ptr, exists, and exp mechanisms to r.MapNonflat case mechanism == "ptr": slog.Debug("Adding nonflat mechanism", "mechanism", mechanism+":"+domain) @@ -156,7 +152,7 @@ func (r *RootSPF) ParseMechanism(mechanism, domain string) error { r.MapNonflat[mechanism] = true // Recursive call to FlattenSPF on `include` and `redirect` mechanism case includeOrRedirectRegex.MatchString(mechanism): - return r.FlattenSPF(mechanism[strings.IndexAny(mechanism, ":=")+1:], "") + return r.FlattenSPF(mechanism[strings.IndexAny(mechanism, ":=")+1:], "", childNode) // Return error if no match default: return fmt.Errorf("received unexpected SPF mechanism or syntax: '%s'", mechanism) @@ -165,13 +161,15 @@ func (r *RootSPF) ParseMechanism(mechanism, domain string) error { } // Convert A/AAAA records to IPs and add them to r.MapIPs -func (r *RootSPF) ConvertDomainToIP(domain, prefixLength string) error { +func (r *RootSPF) ConvertDomainToIP(domain, prefixLength string, parent Node) error { slog.Debug("Looking up IP records for domain", "domain", domain) ips, err := r.LookupIF.LookupIP(domain) if err != nil { return fmt.Errorf("could not lookup IPs for %s: %s\n", domain, err) } for _, ip := range ips { + childNode := &node{name: writeIPMech(ip, prefixLength), parent: parent} + parent.AddChild(childNode) slog.Debug("Adding IP mechanism", "mechanism", writeIPMech(ip, prefixLength)) r.MapIPs[writeIPMech(ip, prefixLength)] = true } @@ -179,23 +177,45 @@ func (r *RootSPF) ConvertDomainToIP(domain, prefixLength string) error { } // Convert MX records to domains then to IPs and add them to r.MapIPs -func (r *RootSPF) ConvertMxToIP(domain, prefixLength string) error { +func (r *RootSPF) ConvertMxToIP(domain, prefixLength string, parent Node) error { slog.Debug("Looking up MX records for domain", "domain", domain) mxs, err := r.LookupIF.LookupMX(domain) if err != nil { return fmt.Errorf("could not lookup MX records for %s: %s\n", domain, 
err) } for _, mx := range mxs { + childNode := &node{name: mx.Host, parent: parent} + parent.AddChild(childNode) slog.Debug("Found MX record for domain", "mx_record", mx.Host) - if err := r.ConvertDomainToIP(mx.Host, prefixLength); err != nil { + if err := r.ConvertDomainToIP(mx.Host, prefixLength, childNode); err != nil { return fmt.Errorf("could not lookup IPs for MX record `%s`: %s\n", mx.Host, err) } } return nil } +// Remove all mechanisms flattened by "keeps", add "keeps" to r.MapNonflat +// and count DNS lookups required by final SPF record +func (r *RootSPF) UnflattenKeeps() { + r.LookupCount += len(r.MapNonflat) + for _, keep := range strings.Fields(r.Keeps) { + if keepNode := r.TraceTree.FindNode(keep); keepNode != nil { + keepSubtree := r.TraceTree.GetSubtree(keepNode, []string{}) + for _, node := range keepSubtree { + if !strings.HasPrefix(node, "ip") && !strings.HasSuffix(node, "all") { + r.LookupCount += 1 + } + delete(r.MapIPs, node) + delete(r.MapNonflat, node) + } + r.MapNonflat[keep] = true + } + } +} + // Flatten and write new SPF record for root domain by compiling r.AllMechanism, r.MapIPs, and r.MapNonflat func (r *RootSPF) WriteFlatSPF() string { + r.UnflattenKeeps() flatSPF := "v=spf1" for ip := range r.MapIPs { flatSPF += " " + ip diff --git a/internal/spf/flatten_test.go b/internal/spf/flatten_test.go index d12c589..ed6b438 100644 --- a/internal/spf/flatten_test.go +++ b/internal/spf/flatten_test.go @@ -63,12 +63,12 @@ func (r *RootSPF) compareExpected(err error, expAll string, expIPs, expNF []stri func TestParseMechanismAll(t *testing.T) { r := NewRootSPF("myrootdomain", mockLookup{}, "") // Test `all`` mechanism set if domain is root - err := r.ParseMechanism("~all", "myrootdomain") + err := r.ParseMechanism("~all", "myrootdomain", r.TraceTree.Root()) if err = r.compareExpected(err, " ~all", []string{}, []string{}); err != nil { t.Fatal(err) } // Test `all`` mechanism is ignored if domain is NOT root - err = r.ParseMechanism("-all", 
"NOTmyrootdomain") + err = r.ParseMechanism("-all", "NOTmyrootdomain", r.TraceTree.Root()) if err = r.compareExpected(err, " ~all", []string{}, []string{}); err != nil { t.Fatal(err) } @@ -79,7 +79,7 @@ func TestParseMechanismIP(t *testing.T) { // Test ip mechanisms of the form `ip4:`, `ip4:/, `ip6:`, and `ip6:/` ipMechs := []string{"ip4:abcd", "ip4:8.8.8.8", "ip6:efgh/36", "ip6:2001:0db8:85a3:0000:0000:8a2e:0370:7334", "ip6:11:22::33/128"} for _, mech := range ipMechs { - err := r.ParseMechanism(mech, "") + err := r.ParseMechanism(mech, "", r.TraceTree.Root()) if err = r.compareExpected(err, "", []string{mech}, []string{}); err != nil { t.Fatal(err) } @@ -101,7 +101,7 @@ func TestParseMechanismA(t *testing.T) { for _, ip := range ipLookup[testCase[2]] { expIPs = append(expIPs, writeIPMech(ip, testCase[1])) } - err := r.ParseMechanism(testCase[0], testCase[3]) + err := r.ParseMechanism(testCase[0], testCase[3], r.TraceTree.Root()) if err = r.compareExpected(err, "", expIPs, []string{}); err != nil { t.Fatal(err) } @@ -125,7 +125,7 @@ func TestParseMechanismMX(t *testing.T) { expIPs = append(expIPs, writeIPMech(ip, testCase[1])) } } - err := r.ParseMechanism(testCase[0], testCase[3]) + err := r.ParseMechanism(testCase[0], testCase[3], r.TraceTree.Root()) if err = r.compareExpected(err, "", expIPs, []string{}); err != nil { t.Fatal(err) } @@ -135,14 +135,14 @@ func TestParseMechanismMX(t *testing.T) { func TestParseMechanismNonFlat(t *testing.T) { r := NewRootSPF("", mockLookup{}, "") // Test ptr mechanism of the form `ptr` - err := r.ParseMechanism("ptr", "domain") + err := r.ParseMechanism("ptr", "domain", r.TraceTree.Root()) if err = r.compareExpected(err, "", []string{}, []string{"ptr:domain"}); err != nil { t.Fatal(err) } // Test nonflat mechanisms of the form `ptr:`, ``, and `exp=` nfMechs := []string{"ptr:example.com", "exists:yourdomain", "exp=explain.example.com"} for _, nfMech := range nfMechs { - err := r.ParseMechanism(nfMech, "") + err := 
r.ParseMechanism(nfMech, "", r.TraceTree.Root()) if err = r.compareExpected(err, "", []string{}, []string{nfMech}); err != nil { t.Fatal(err) } @@ -157,7 +157,7 @@ func TestParseMechanismInclude(t *testing.T) { for _, ip := range ipLookup[includeDomain] { expIPs = append(expIPs, writeIPMech(ip, "")) } - err := r.ParseMechanism("include:"+includeDomain, "notmydomain") + err := r.ParseMechanism("include:"+includeDomain, "notmydomain", r.TraceTree.Root()) if err = r.compareExpected(err, "", expIPs, []string{}); err != nil { t.Fatal(err) } @@ -167,7 +167,7 @@ func TestParseMechanismRedirect(t *testing.T) { r := NewRootSPF("", mockLookup{}, "") // Test mechanism of the form `redirect=` redirectDomain := "test.com" // SPF record is just ip4:10.10.10.10 - err := r.ParseMechanism("redirect="+redirectDomain, "notmydomain") + err := r.ParseMechanism("redirect="+redirectDomain, "notmydomain", r.TraceTree.Root()) if err = r.compareExpected(err, "", []string{"ip4:10.10.10.10"}, []string{}); err != nil { t.Fatal(err) } @@ -178,7 +178,7 @@ func TestParseMechanismFails(t *testing.T) { // Test that parseMechanism fails on unexpected mechanism or syntax error noMatchRegex := regexp.MustCompile(`^received unexpected SPF mechanism or syntax.*$`) for _, wrongMech := range []string{"redirect:domain", "include=anotherdomain", "ip:0.0.0.0", "1.1.1.1", "", "ip6", "exp:explanation", "notMechanism:hello"} { - err := r.ParseMechanism(wrongMech, "") + err := r.ParseMechanism(wrongMech, "", r.TraceTree.Root()) if !noMatchRegex.MatchString(err.Error()) { t.Fatalf("Expected `received unexpected SPF mechanism or syntax` error, got `%s` instead", err) } @@ -200,7 +200,7 @@ func TestFlattenSPF(t *testing.T) { expIPs = append(expIPs, writeIPMech(ip, "")) } } - err := r.FlattenSPF(domain, spf) + err := r.FlattenSPF(domain, spf, r.TraceTree.Root()) if err = r.compareExpected(err, " -all", expIPs, expNFs); err != nil { t.Fatal(err) } @@ -218,7 +218,8 @@ func TestNoFlattenKeeps(t *testing.T) { expIPs = 
append(expIPs, writeIPMech(ip, "")) } } - err := r.FlattenSPF(domain, spf) + err := r.FlattenSPF(domain, spf, r.TraceTree.Root()) + r.UnflattenKeeps() if err = r.compareExpected(err, " -all", expIPs, expNFs); err != nil { t.Fatal(err) } @@ -241,7 +242,7 @@ func TestFlattenRedirects(t *testing.T) { for _, ip := range ipLookup["mydomain"] { expIPs = append(expIPs, writeIPMech(ip, "")) } - err := r.FlattenSPF(domain, spf) + err := r.FlattenSPF(domain, spf, r.TraceTree.Root()) if err = r.compareExpected(err, "", expIPs, []string{}); err != nil { t.Fatal(err) } @@ -250,7 +251,7 @@ func TestFlattenRedirects(t *testing.T) { } // Test that redirects are ignored if SPF record includes `all` mechanism r.MapIPs = map[string]bool{} - err = r.FlattenSPF(domain, spf+" ~all") + err = r.FlattenSPF(domain, spf+" ~all", r.TraceTree.Root()) if err = r.compareExpected(err, "", []string{"ip4:9.9.9.9"}, []string{}); err != nil { t.Fatal(err) } @@ -270,7 +271,7 @@ func TestFlattenModifiers(t *testing.T) { expIPs = append(expIPs, writeIPMech(ip, "")) } } - err := r.FlattenSPF(domain, spf) + err := r.FlattenSPF(domain, spf, r.TraceTree.Root()) if err = r.compareExpected(err, " -all", expIPs, []string{}); err != nil { t.Fatal(err) } diff --git a/internal/spf/record.go b/internal/spf/record.go index 20bc87c..eb53c75 100644 --- a/internal/spf/record.go +++ b/internal/spf/record.go @@ -56,6 +56,26 @@ func CheckSPFRecord(domain, spfRecord string, lookupIF Lookup) error { return fmt.Errorf("SPF record for %s did not match expected format. Got '%s'", domain, spfRecord) } +// For each mechanism in "mechanisms", lookup trace of changes/flattening +// and store in map with traces as keys and group of mechanisms with shared trace +// as values, then return map. E.g. 
// TraceChanges groups the whitespace-separated mechanisms by the chain of
// ancestors that produced them in traceTree. Keys are the ancestor names
// joined with " => " (root first); values are the space-separated mechanisms
// sharing that chain, e.g.
//
//	map["example.com => include:example.com => a"] = "ip4:0.0.0.0 ip4:1.2.3.4"
//
// A mechanism that does not appear in the tree (for instance one that exists
// only in the currently published record) has no trace: it is reported on
// stdout and skipped instead of being dereferenced.
func TraceChanges(mechanisms string, traceTree Tree) map[string]string {
	mapTraces := map[string]string{}
	for _, mech := range strings.Fields(mechanisms) {
		found := traceTree.FindNode(mech)
		if found == nil {
			fmt.Printf("couldn't find node %s\n", mech)
			continue
		}
		trace := strings.Join(traceTree.GetAncestors(found), " => ")
		if group, ok := mapTraces[trace]; ok {
			mapTraces[trace] = group + " " + mech
		} else {
			mapTraces[trace] = mech
		}
	}
	return mapTraces
}

// Tree records how each mechanism was derived while flattening: every node is
// a mechanism (or resolved host/IP) and its children are what it expanded to.
type Tree struct {
	root Node
}

// Root returns the root node of the tree; it is nil for a zero-value Tree.
func (t *Tree) Root() Node {
	return t.root
}

// GetSubtree appends the names of node and all of its descendants, in
// depth-first pre-order, to listNodes and returns the result. A nil node
// contributes nothing.
func (t *Tree) GetSubtree(node Node, listNodes []string) []string {
	if node == nil {
		return listNodes
	}
	listNodes = append(listNodes, node.GetName())
	for _, child := range node.GetChildren() {
		listNodes = t.GetSubtree(child, listNodes)
	}
	return listNodes
}

// GetAncestors returns the names of n's ancestors ordered from the root down
// to n's direct parent. The result is empty (never nil) when n is nil or has
// no parent.
func (t *Tree) GetAncestors(n Node) []string {
	listAncestors := []string{}
	if n == nil {
		return listAncestors
	}
	// Walk upwards collecting parent-first, then flip to root-first order.
	for p := n.GetParent(); p != nil; p = p.GetParent() {
		listAncestors = append(listAncestors, p.GetName())
	}
	for i, j := 0, len(listAncestors)-1; i < j; i, j = i+1, j-1 {
		listAncestors[i], listAncestors[j] = listAncestors[j], listAncestors[i]
	}
	return listAncestors
}

// FindNode returns the first node (depth-first, pre-order) whose name equals
// name, or nil if the tree is empty or holds no such node.
func (t *Tree) FindNode(name string) Node {
	return t.FindSubtreeNode(t.Root(), name)
}

// FindSubtreeNode searches node and its descendants depth-first and returns
// the first node whose name equals name, or nil if there is none.
func (t *Tree) FindSubtreeNode(node Node, name string) Node {
	if node == nil {
		return nil
	}
	if node.GetName() == name {
		return node
	}
	for _, child := range node.GetChildren() {
		if match := t.FindSubtreeNode(child, name); match != nil {
			return match
		}
	}
	return nil
}

// PrintTree writes the whole tree to stdout, one node per line, indented by
// depth.
func (t *Tree) PrintTree() {
	t.PrintSubtree(t.Root(), 0)
}

// PrintSubtree writes node and its descendants to stdout, indenting node by
// tabs tab characters and each level below it by one more. A nil node prints
// nothing.
func (t *Tree) PrintSubtree(node Node, tabs int) {
	if node == nil {
		return
	}
	node.PrintNode(tabs)
	for _, child := range node.GetChildren() {
		t.PrintSubtree(child, tabs+1)
	}
}

// Node is a single element of a Tree.
type Node interface {
	GetName() string
	GetParent() Node
	GetChildren() []Node
	AddChild(Node)
	PrintNode(tabs int)
}

// node is the Node implementation used by the flattener.
type node struct {
	name     string
	parent   Node
	children []Node
}

// GetName returns the node's name.
func (n *node) GetName() string {
	return n.name
}

// GetParent returns the node's parent, or nil for a root node.
func (n *node) GetParent() Node {
	return n.parent
}

// GetChildren returns the node's children in insertion order.
func (n *node) GetChildren() []Node {
	return n.children
}

// AddChild appends child to n's list of children. It does not set the
// child's parent; callers do that when constructing the child. A nil child
// is ignored so that traversals never meet a nil entry.
func (n *node) AddChild(child Node) {
	if child == nil {
		return
	}
	n.children = append(n.children, child)
}

// PrintNode writes the node's name followed by the comma-separated names of
// its direct children to stdout, preceded by tabs tab characters.
func (n *node) PrintNode(tabs int) {
	listChildren := make([]string, 0, len(n.children))
	for _, child := range n.children {
		listChildren = append(listChildren, child.GetName())
	}
	fmt.Println(strings.Repeat("\t", tabs), n.GetName(), strings.Join(listChildren, ","))
}