diff --git a/cmd/root.go b/cmd/root.go index a41894b..698287b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -6,24 +6,16 @@ import ( "os" "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/spf13/cobra" - "github.com/rs/zerolog/log" + "github.com/jpts/coredns-enum/internal/types" + "github.com/jpts/coredns-enum/internal/util" + "github.com/jpts/coredns-enum/pkg/dnsclient" + "github.com/jpts/coredns-enum/pkg/scanners" ) -type cliOpts struct { - loglevel int - maxWorkers int - cidrRange string - nameserver string - nameport int - timeout float32 - mode string - zone string - proto string -} - -var opts cliOpts +var opts types.CliOpts var rootCmd = &cobra.Command{ Use: "coredns-enum", @@ -34,46 +26,46 @@ var rootCmd = &cobra.Command{ var err error log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) - lvl, err := zerolog.ParseLevel(fmt.Sprint(opts.loglevel)) + lvl, err := zerolog.ParseLevel(fmt.Sprint(opts.LogLevel)) if err != nil { return errors.New("Error setting up logging") } zerolog.SetGlobalLevel(lvl) - if opts.proto != "udp" && opts.proto != "tcp" && opts.proto != "auto" { + if opts.Proto != "udp" && opts.Proto != "tcp" && opts.Proto != "auto" { log.Error().Msg("Invalid protocol") } - initDNS() + dclient := dnsclient.InitDNS(&opts) - if opts.nameserver == "" { - opts.nameserver, opts.nameport, err = getNSFromSystem() + if opts.Nameserver == "" { + opts.Nameserver, opts.Nameport, err = dclient.GetNSFromSystem() if err != nil { return err } - log.Info().Msgf("Detected nameserver as %s:%d", opts.nameserver, opts.nameport) + log.Info().Msgf("Detected nameserver as %s:%d", opts.Nameserver, opts.Nameport) } - if opts.mode == MODE_AUTO { - opts.mode = detectMode() + if opts.Mode == dnsclient.MODE_AUTO { + opts.Mode = dclient.DetectMode() } - var res []*svcResult - switch opts.mode { - case MODE_BRUTEFORCE: - res, err = brute(&opts) - case MODE_WILDCARD: - res, err = wildcard(&opts) - case MODE_FAILED: + var res []*types.SvcResult + switch opts.Mode { + case dnsclient.MODE_BRUTEFORCE: + res, err = scanners.BruteScan(&opts, dclient) + case dnsclient.MODE_WILDCARD: + res, err = scanners.WildcardScan(&opts, dclient) + case dnsclient.MODE_FAILED: err = fmt.Errorf("Failed to detect mode automatically") default: - err = fmt.Errorf("Unsupported mode: %s", opts.mode) + err = fmt.Errorf("Unsupported mode: %s", opts.Mode) } if err != nil { return err } - renderResults(res) + util.RenderResults(res) return nil }, @@ -90,17 +82,17 @@ func Execute() { func init() { // global flags - rootCmd.PersistentFlags().IntVarP(&opts.loglevel, "loglevel", "v", 1, "Set loglevel (-1 => 5)") - rootCmd.PersistentFlags().StringVarP(&opts.mode, "mode", "m", "auto", "Select mode: wildcard|bruteforce|auto") - rootCmd.PersistentFlags().StringVar(&opts.zone, "zone", "cluster.local", "DNS zone") + rootCmd.PersistentFlags().IntVarP(&opts.LogLevel, "loglevel", "v", 1, "Set loglevel (-1 => 5)") + rootCmd.PersistentFlags().StringVarP(&opts.Mode, "mode", "m", "auto", "Select mode: wildcard|bruteforce|auto") + rootCmd.PersistentFlags().StringVar(&opts.Zone, "zone", "cluster.local", "DNS zone") // bruteforce - rootCmd.Flags().IntVarP(&opts.maxWorkers, "max-workers", "t", 50, "Number of 'workers' to use for concurrency") - rootCmd.Flags().StringVar(&opts.cidrRange, "cidr", "", "Range to scan in bruteforce mode") + rootCmd.Flags().IntVarP(&opts.MaxWorkers, "max-workers", "t", 50, "Number of 'workers' to use for concurrency") + rootCmd.Flags().StringVar(&opts.CidrRange, "cidr", "", "Range to scan in bruteforce mode") // nameserver - rootCmd.Flags().StringVarP(&opts.nameserver, "nsip", "n", "", "Nameserver to use (detected by default)") - rootCmd.Flags().IntVar(&opts.nameport, "nsport", 53, "Nameserver port to use (detected by default)") - rootCmd.Flags().Float32Var(&opts.timeout, "timeout", 0.5, "DNS query timeout (seconds)") - rootCmd.Flags().StringVar(&opts.proto, "protocol", "auto", "DNS protocol: udp|tcp|auto") + rootCmd.Flags().StringVarP(&opts.Nameserver, "nsip", "n", "", "Nameserver to use (detected by default)") + rootCmd.Flags().IntVar(&opts.Nameport, "nsport", 53, "Nameserver port to use (detected by default)") + rootCmd.Flags().Float32Var(&opts.Timeout, "timeout", 0.5, "DNS query timeout (seconds)") + rootCmd.Flags().StringVar(&opts.Proto, "protocol", "auto", "DNS protocol: udp|tcp|auto") } diff --git a/internal/types/cli.go b/internal/types/cli.go new file mode 100644 index 0000000..e010d13 --- /dev/null +++ b/internal/types/cli.go @@ -0,0 +1,13 @@ +package types + +type CliOpts struct { + LogLevel int + MaxWorkers int + CidrRange string + Nameserver string + Nameport int + Timeout float32 + Mode string + Zone string + Proto string +} diff --git a/internal/types/results.go b/internal/types/results.go new file mode 100644 index 0000000..aa7c179 --- /dev/null +++ b/internal/types/results.go @@ -0,0 +1,42 @@ +package types + +import ( + "fmt" + "net" + "time" + + "github.com/miekg/dns" +) + +type QueryResult struct { + Answers []dns.RR + Additional []dns.RR + Raw *dns.Msg + IP *net.IP + RTT *time.Duration +} + +type SvcResult struct { + Name string + Namespace string + IP *net.IP + Ports []*PortResult + Endpoints []*PodResult +} + +func (s *SvcResult) String() string { return fmt.Sprintf("%s/%s", s.Namespace, s.Name) } + +type PodResult struct { + Name string + Namespace string + IP *net.IP + Ports []*PortResult +} + +type PortResult struct { + Proto string + PortNo int + PortName string +} + +func (p *PortResult) String() string { return fmt.Sprintf("%d/%s", p.PortNo, p.Proto) } diff --git a/internal/util/array.go b/internal/util/array.go new file mode 100644 index 0000000..1137f20 --- /dev/null +++ b/internal/util/array.go @@ -0,0 +1,18 @@ +package util + +func Reverse(numbers []string) []string { + newNumbers := make([]string, len(numbers)) + for i, j := 0, len(numbers)-1; i <= j; i, j = i+1, j-1 { + newNumbers[i], newNumbers[j] = numbers[j], numbers[i] + } + return newNumbers +} + +func IsElement(s []string, str string) bool { + for _, v := range s { + if v == str { + return true + } + } + return false +} diff --git a/internal/util/cidr.go b/internal/util/cidr.go new file mode 100644 index 0000000..cb64ec0 --- /dev/null +++ b/internal/util/cidr.go @@ -0,0 +1,33 @@ +package util + +import ( + "fmt" + + "github.com/seancfoley/ipaddress-go/ipaddr" + "github.com/seancfoley/ipaddress-go/ipaddr/addrstrparam" +) + +func ParseIPv4CIDR(cidr string) (*ipaddr.IPAddress, error) { + pb := addrstrparam.IPAddressStringParamsBuilder{} + pb.AllowWildcardedSeparator(true) + pb.AllowIPv4(true) + pb.AllowIPv6(false) + pb.AllowMask(true) + pb.AllowPrefix(true) + pb.AllowEmpty(false) + pb.AllowSingleSegment(false) + params := pb.ToParams() + + ipastr := ipaddr.NewIPAddressStringParams(cidr, params) + + if !ipastr.IsPrefixed() { + return nil, fmt.Errorf("CIDR %s requires prefix, use /32 for a single host", cidr) + } + + subnet, err := ipastr.ToAddress() + if err != nil { + return nil, err + } + + return subnet.ToPrefixBlock(), nil +} diff --git a/cmd/printer.go b/internal/util/printer.go similarity index 56% rename from cmd/printer.go rename to internal/util/printer.go index 7bf6a8f..bf19866 100644 --- a/cmd/printer.go +++ b/internal/util/printer.go @@ -1,63 +1,15 @@ -package cmd +package util import ( "fmt" - "net" "os" "sort" - "time" - "github.com/miekg/dns" + "github.com/jpts/coredns-enum/internal/types" "github.com/olekukonko/tablewriter" ) -type queryResult struct { - answers []dns.RR - additional []dns.RR - raw *dns.Msg - ip *net.IP - rtt *time.Duration -} - -type svcResult struct { - Name string - Namespace string - IP *net.IP - Ports []*portResult - Endpoints []*podResult -} - -func (s *svcResult) String() string { return fmt.Sprintf("%s/%s", s.Namespace, s.Name) } - -type podResult struct { - Name string - Namespace string - IP *net.IP - Ports []*portResult -} - -type portResult struct { - Proto string - PortNo int - PortName string -} - -func (p *portResult) String() string { return fmt.Sprintf("%d/%s", p.PortNo, p.Proto) } - -type SortByNsName []*svcResult - -func (a SortByNsName) Len() int { return len(a) } - -func (a SortByNsName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } - -func (a SortByNsName) Less(i, j int) bool { - if a[i].Namespace == a[j].Namespace { - return a[i].Name < a[j].Name - } - return a[i].Namespace < a[j].Namespace -} - -func renderResults(res []*svcResult) { +func RenderResults(res []*types.SvcResult) { var output [][]string sort.Sort(SortByNsName(res)) diff --git a/internal/util/sortsvc.go b/internal/util/sortsvc.go new file mode 100644 index 0000000..8faf5d9 --- /dev/null +++ b/internal/util/sortsvc.go @@ -0,0 +1,16 @@ +package util + +import "github.com/jpts/coredns-enum/internal/types" + +type SortByNsName []*types.SvcResult + +func (a SortByNsName) Len() int { return len(a) } + +func (a SortByNsName) Swap(i, j int) { a[i], a[j] = a[j], a[i] } + +func (a SortByNsName) Less(i, j int) bool { + if a[i].Namespace == a[j].Namespace { + return a[i].Name < a[j].Name + } + return a[i].Namespace < a[j].Namespace +} diff --git a/cmd/dns.go b/pkg/dnsclient/dns.go similarity index 55% rename from cmd/dns.go rename to pkg/dnsclient/dns.go index 38dbfd0..f363e23 100644 --- a/cmd/dns.go +++ b/pkg/dnsclient/dns.go @@ -1,4 +1,4 @@ -package cmd +package dnsclient import ( "errors" @@ -10,29 +10,48 @@ import ( "github.com/miekg/dns" "github.com/rs/zerolog/log" + + "github.com/jpts/coredns-enum/internal/types" + "github.com/jpts/coredns-enum/internal/util" ) -var clientUDP = &dns.Client{Net: "udp"} -var clientTCP = &dns.Client{Net: "tcp"} +type DNSClient struct { + UDPClient *dns.Client + TCPClient *dns.Client + CliOpts *types.CliOpts +} -func initDNS() { - dur, _ := time.ParseDuration(fmt.Sprintf("%fs", opts.timeout)) +func InitDNS(opts *types.CliOpts) *DNSClient { + dur, _ := time.ParseDuration(fmt.Sprintf("%fs", opts.Timeout)) if dur < time.Microsecond { dur = time.Microsecond } log.Debug().Msgf("timeout configured: %s", dur) - clientUDP.Timeout = dur - clientTCP.Timeout = dur + + clientUDP := &dns.Client{ + Net: "udp", + Timeout: dur, + } + clientTCP := &dns.Client{ + Net: "tcp", + Timeout: dur, + } + + return &DNSClient{ + UDPClient: clientUDP, + TCPClient: clientTCP, + CliOpts: opts, + } } -func getNSFromSystem() (string, int, error) { +func (d *DNSClient) GetNSFromSystem() (string, int, error) { conf, err := dns.ClientConfigFromFile("/etc/resolv.conf") if err != nil { return "", 0, fmt.Errorf("Error making client from resolv.conf: %w", err) } - if !isElement(conf.Search, fmt.Sprintf("svc.%s", opts.zone)) { - log.Warn().Msgf("Unabled to validate k8s zone (%s)", opts.zone) + if !util.IsElement(conf.Search, fmt.Sprintf("svc.%s", d.CliOpts.Zone)) { + log.Warn().Msgf("Unabled to validate k8s zone (%s)", d.CliOpts.Zone) } port, err := strconv.Atoi(conf.Port) @@ -43,7 +62,7 @@ func getNSFromSystem() (string, int, error) { return conf.Servers[0], port, nil } -func queryPTR(ip net.IP) (*queryResult, error) { +func (d *DNSClient) QueryPTR(ip net.IP) (*types.QueryResult, error) { m := &dns.Msg{ Question: make([]dns.Question, 1), @@ -52,7 +71,7 @@ func queryPTR(ip net.IP) (*queryResult, error) { }, } - revip := strings.Join(reverse(strings.Split(ip.String(), ".")), ".") + revip := strings.Join(util.Reverse(strings.Split(ip.String(), ".")), ".") ptr := fmt.Sprintf("%s.in-addr.arpa.", revip) fqdn := dns.Fqdn(ptr) @@ -62,7 +81,7 @@ func queryPTR(ip net.IP) (*queryResult, error) { Qtype: dns.TypePTR, Qclass: dns.ClassINET, } - res, err := multiProtoQueryRecord(m) + res, err := d.MultiProtoQueryRecord(m) if err != nil { return nil, err } @@ -70,16 +89,16 @@ func queryPTR(ip net.IP) (*queryResult, error) { return nil, nil } - return &queryResult{ - answers: res.answers, - additional: res.additional, - raw: res.raw, - ip: &ip, - rtt: res.rtt, + return &types.QueryResult{ + Answers: res.Answers, + Additional: res.Additional, + Raw: res.Raw, + IP: &ip, + RTT: res.RTT, }, nil } -func queryA(aname string) (*queryResult, error) { +func (d *DNSClient) QueryA(aname string) (*types.QueryResult, error) { m := &dns.Msg{ Question: make([]dns.Question, 1), MsgHdr: dns.MsgHdr{ @@ -93,10 +112,10 @@ func queryA(aname string) (*queryResult, error) { Qtype: dns.TypeA, Qclass: dns.ClassINET, } - return multiProtoQueryRecord(m) + return d.MultiProtoQueryRecord(m) } -func querySRV(aname string) (*queryResult, error) { +func (d *DNSClient) QuerySRV(aname string) (*types.QueryResult, error) { m := &dns.Msg{ Question: make([]dns.Question, 1), MsgHdr: dns.MsgHdr{ @@ -110,10 +129,10 @@ func querySRV(aname string) (*queryResult, error) { Qtype: dns.TypeSRV, Qclass: dns.ClassINET, } - return multiProtoQueryRecord(m) + return d.MultiProtoQueryRecord(m) } -func queryTXT(txt string) (*queryResult, error) { +func (d *DNSClient) QueryTXT(txt string) (*types.QueryResult, error) { m := &dns.Msg{ Question: make([]dns.Question, 1), MsgHdr: dns.MsgHdr{ @@ -127,24 +146,24 @@ func queryTXT(txt string) (*queryResult, error) { Qtype: dns.TypeTXT, Qclass: dns.ClassINET, } - return multiProtoQueryRecord(m) + return d.MultiProtoQueryRecord(m) } -func multiProtoQueryRecord(m *dns.Msg) (*queryResult, error) { - switch opts.proto { +func (d *DNSClient) MultiProtoQueryRecord(m *dns.Msg) (*types.QueryResult, error) { + switch d.CliOpts.Proto { case "auto": - return autoProtoQueryRecord(m) + return d.AutoProtoQueryRecord(m) case "udp": - return queryRecord(clientUDP, m) + return d.queryRecord(d.UDPClient, m) case "tcp": - return queryRecord(clientTCP, m) + return d.queryRecord(d.TCPClient, m) default: - return nil, fmt.Errorf("Unknown protocol %s", opts.proto) + return nil, fmt.Errorf("Unknown protocol %s", d.CliOpts.Proto) } } -func autoProtoQueryRecord(m *dns.Msg) (*queryResult, error) { - res, err := queryRecord(clientUDP, m) +func (d *DNSClient) AutoProtoQueryRecord(m *dns.Msg) (*types.QueryResult, error) { + res, err := d.queryRecord(d.UDPClient, m) if err != nil { return nil, err } @@ -153,17 +172,17 @@ func autoProtoQueryRecord(m *dns.Msg) (*queryResult, error) { return nil, nil } - if !res.raw.Truncated { + if !res.Raw.Truncated { return res, nil } log.Debug().Msgf("Got truncated response for %s, retrying with TCP", m.Question[0].Name) - return queryRecord(clientTCP, m) + return d.queryRecord(d.TCPClient, m) } -func queryRecord(client *dns.Client, m *dns.Msg) (*queryResult, error) { - r, rtt, err := client.Exchange(m, fmt.Sprintf("%s:%d", opts.nameserver, opts.nameport)) +func (d *DNSClient) queryRecord(client *dns.Client, m *dns.Msg) (*types.QueryResult, error) { + r, rtt, err := client.Exchange(m, fmt.Sprintf("%s:%d", d.CliOpts.Nameserver, d.CliOpts.Nameport)) if err != nil { var dnsError *net.OpError if errors.As(err, &dnsError) && strings.Contains(err.Error(), "timeout") { @@ -173,17 +192,17 @@ func queryRecord(client *dns.Client, m *dns.Msg) (*queryResult, error) { } if r != nil && len(r.Answer) > 0 { - return &queryResult{ - answers: r.Answer, - additional: r.Extra, - raw: r, - rtt: &rtt, + return &types.QueryResult{ + Answers: r.Answer, + Additional: r.Extra, + Raw: r, + RTT: &rtt, }, nil } return nil, nil } -func parseSRVAnswer(ans string) (string, string, int, error) { +func ParseSRVAnswer(ans string) (string, string, int, error) { parts := strings.Split(ans, "\t") if len(parts) != 5 { return "", "", 0, fmt.Errorf("Error parsing SRV: %s", ans) @@ -196,16 +215,16 @@ func parseSRVAnswer(ans string) (string, string, int, error) { if err != nil { return "", "", 0, err } - name, ns := parseDNSPodName(parts4[3]) + name, ns := ParseDNSPodName(parts4[3]) return name, ns, port, nil } -func parseAAnswer(ans string) (string, string, net.IP, error) { +func ParseAAnswer(ans string) (string, string, net.IP, error) { parts := strings.Split(ans, "\t") if len(parts) != 5 { return "", "", nil, fmt.Errorf("Error parsing A: %s", ans) } - name, ns := parseDNSPodName(parts[0]) + name, ns := ParseDNSPodName(parts[0]) ip := net.ParseIP(parts[4]) if ip == nil { return "", "", nil, fmt.Errorf("Error parsing IP address: %s", parts[4]) @@ -213,7 +232,7 @@ func parseAAnswer(ans string) (string, string, net.IP, error) { return name, ns, ip, nil } -func parseDNSPodName(fqdn string) (string, string) { +func ParseDNSPodName(fqdn string) (string, string) { parts := strings.Split(fqdn, ".") if len(parts) == 7 { diff --git a/cmd/dns_test.go b/pkg/dnsclient/dns_test.go similarity index 93% rename from cmd/dns_test.go rename to pkg/dnsclient/dns_test.go index 73c472e..2431138 100644 --- a/cmd/dns_test.go +++ b/pkg/dnsclient/dns_test.go @@ -1,4 +1,4 @@ -package cmd +package dnsclient import ( "net" @@ -31,7 +31,7 @@ func TestParseAAnswer(t *testing.T) { in := tcase["in"].(string) out := tcase["out"].(map[string]any) - name, ns, ip, err := parseAAnswer(in) + name, ns, ip, err := ParseAAnswer(in) assert.Equal(t, out["name"], name) assert.Equal(t, out["ns"], ns) @@ -64,7 +64,7 @@ func TestParseSRVAnswer(t *testing.T) { in := tcase["in"].(string) out := tcase["out"].(map[string]any) - name, ns, port, err := parseSRVAnswer(in) + name, ns, port, err := ParseSRVAnswer(in) assert.Equal(t, out["name"], name) assert.Equal(t, out["ns"], ns) diff --git a/pkg/dnsclient/mode.go b/pkg/dnsclient/mode.go new file mode 100644 index 0000000..344ddb2 --- /dev/null +++ b/pkg/dnsclient/mode.go @@ -0,0 +1,61 @@ +package dnsclient + +import ( + "fmt" + + "github.com/rs/zerolog/log" +) + +const ( + MODE_AUTO = "auto" + MODE_BRUTEFORCE = "bruteforce" + MODE_WILDCARD = "wildcard" + MODE_FAILED = "failed" +) + +func (d *DNSClient) DetectMode() string { + if ok, _ := d.CheckSpecVersion(); !ok { + log.Info().Msg("Unable to detect spec compliant Kubernetes DNS server") + return MODE_FAILED + } + + if ok, _ := d.CheckWildcardK8sAddress(); ok { + log.Info().Msg("Wildcard support detected") + return MODE_WILDCARD + } + + if ok, _ := d.CheckDefaultK8sAddress(); ok { + log.Info().Msg("Falling back to bruteforce mode") + return MODE_BRUTEFORCE + } + + log.Error().Msg("Failed to detect a CoreDNS server") + return MODE_FAILED +} + +func (d *DNSClient) CheckSpecVersion() (bool, error) { + res, err := d.QueryTXT(fmt.Sprintf("dns-version.%s", d.CliOpts.Zone)) + if err != nil { + return false, err + } + + return res != nil, nil +} + +func (d *DNSClient) CheckDefaultK8sAddress() (bool, error) { + res, err := d.QueryA(fmt.Sprintf("kubernetes.default.svc.%s", d.CliOpts.Zone)) + if err != nil { + return false, err + } + + return res != nil, nil +} + +func (d *DNSClient) CheckWildcardK8sAddress() (bool, error) { + res, err := d.QueryA(fmt.Sprintf("any.any.svc.%s", d.CliOpts.Zone)) + if err != nil { + return false, err + } + + return res != nil, nil +} diff --git a/cmd/brute.go b/pkg/scanners/brute.go similarity index 59% rename from cmd/brute.go rename to pkg/scanners/brute.go index 9d5cd19..94cab27 100644 --- a/cmd/brute.go +++ b/pkg/scanners/brute.go @@ -1,4 +1,4 @@ -package cmd +package scanners import ( "fmt" @@ -6,10 +6,12 @@ import ( "strings" "sync" + "github.com/rs/zerolog/log" "github.com/seancfoley/ipaddress-go/ipaddr" - "github.com/seancfoley/ipaddress-go/ipaddr/addrstrparam" - "github.com/rs/zerolog/log" + "github.com/jpts/coredns-enum/internal/types" + "github.com/jpts/coredns-enum/internal/util" + "github.com/jpts/coredns-enum/pkg/dnsclient" ) var srvServices = map[string][]string{ @@ -60,15 +62,15 @@ var srvServices = map[string][]string{ var ipChan = make(chan net.IP) var srvChan = make(chan net.IP) -var ptrResultChan = make(chan queryResult) -var svcChan = make(chan svcResult) -var svcResultChan = make(chan svcResult) +var ptrResultChan = make(chan types.QueryResult) +var svcChan = make(chan types.SvcResult) +var svcResultChan = make(chan types.SvcResult) -func brute(opts *cliOpts) ([]*svcResult, error) { +func BruteScan(opts *types.CliOpts, dclient *dnsclient.DNSClient) ([]*types.SvcResult, error) { var subnets []*ipaddr.IPAddress - if opts.cidrRange == "" { - cert, err := getDefaultAPIServerCert() + if opts.CidrRange == "" { + cert, err := GetDefaultAPIServerCert(dclient.CliOpts.Zone) if err != nil { return nil, err } @@ -79,29 +81,13 @@ func brute(opts *cliOpts) ([]*svcResult, error) { log.Info().Msgf("Guessed %s CIDRs from APIserver cert", subnets) } else { - for _, cidr := range strings.Split(opts.cidrRange, ",") { - pb := addrstrparam.IPAddressStringParamsBuilder{} - pb.AllowWildcardedSeparator(true) - pb.AllowIPv4(true) - pb.AllowIPv6(false) - pb.AllowMask(true) - pb.AllowPrefix(true) - pb.AllowEmpty(false) - pb.AllowSingleSegment(false) - params := pb.ToParams() - - ipastr := ipaddr.NewIPAddressStringParams(cidr, params) - - if !ipastr.IsPrefixed() { - return nil, fmt.Errorf("CIDR %s requires prefix, use /32 for a single host", cidr) - } - - subnet, err := ipastr.ToAddress() + for _, cidr := range strings.Split(opts.CidrRange, ",") { + subnet, err := util.ParseIPv4CIDR(cidr) if err != nil { return nil, err } - subnets = append(subnets, subnet.ToPrefixBlock()) + subnets = append(subnets, subnet) } } @@ -118,11 +104,11 @@ func brute(opts *cliOpts) ([]*svcResult, error) { close(ipChan) }() - // parallelise ptr query scanning + // parallelise ptr dnsclient.Query scanning wg := sync.WaitGroup{} - wg.Add(opts.maxWorkers) - for w := 0; w < opts.maxWorkers; w++ { - go ptrQueryWorker(&wg) + wg.Add(opts.MaxWorkers) + for w := 0; w < opts.MaxWorkers; w++ { + go ptrQueryWorker(&wg, dclient) } go func() { @@ -133,15 +119,15 @@ func brute(opts *cliOpts) ([]*svcResult, error) { // recv results async go func() { for res := range ptrResultChan { - if res.answers != nil { - ans := res.answers[0] + if res.Answers != nil { + ans := res.Answers[0] parts := strings.Split(ans.String(), "\t") - name, ns := parseDNSPodName(parts[len(parts)-1]) - log.Debug().Msgf("Processing svc: %s\t%s", res.ip, parts[len(parts)-1]) - svc := svcResult{ + name, ns := dnsclient.ParseDNSPodName(parts[len(parts)-1]) + log.Debug().Msgf("Processing svc: %s\t%s", res.IP, parts[len(parts)-1]) + svc := types.SvcResult{ Name: name, Namespace: ns, - IP: res.ip, + IP: res.IP, } svcChan <- svc } @@ -151,9 +137,9 @@ func brute(opts *cliOpts) ([]*svcResult, error) { // parallelise service port bruteforcing wg2 := sync.WaitGroup{} - wg2.Add(opts.maxWorkers) - for w := 0; w < opts.maxWorkers; w++ { - go svcPortScanWorker(&wg2) + wg2.Add(opts.MaxWorkers) + for w := 0; w < opts.MaxWorkers; w++ { + go svcPortScanWorker(&wg2, dclient) } go func() { @@ -161,10 +147,10 @@ func brute(opts *cliOpts) ([]*svcResult, error) { close(svcResultChan) }() - var svcs []*svcResult + var svcs []*types.SvcResult for res := range svcResultChan { // new object needed as objects are clobbered in channel range - obj := &svcResult{ + obj := &types.SvcResult{ Name: res.Name, Namespace: res.Namespace, IP: res.IP, @@ -177,9 +163,9 @@ func brute(opts *cliOpts) ([]*svcResult, error) { return svcs, nil } -func ptrQueryWorker(wg *sync.WaitGroup) { +func ptrQueryWorker(wg *sync.WaitGroup, dclient *dnsclient.DNSClient) { for ip := range ipChan { - res, err := queryPTR(ip) + res, err := dclient.QueryPTR(ip) if err != nil { log.Info().Msgf("Retrying failed ip %s: %s", ip, err.Error()) ipChan <- ip @@ -192,16 +178,16 @@ func ptrQueryWorker(wg *sync.WaitGroup) { wg.Done() } -func svcPortScanWorker(wg *sync.WaitGroup) { +func svcPortScanWorker(wg *sync.WaitGroup, dclient *dnsclient.DNSClient) { for svc := range svcChan { for proto, srvSvcList := range srvServices { for _, svcName := range srvSvcList { - res, err := querySRV(fmt.Sprintf("%s._%s.%s.%s.svc.%s", + res, err := dclient.QuerySRV(fmt.Sprintf("%s._%s.%s.%s.svc.%s", svcName, proto, svc.Name, svc.Namespace, - opts.zone, + dclient.CliOpts.Zone, )) if err != nil { log.Warn().Msgf("SRV request failed %s/%s: %s", svcName, proto, err.Error()) @@ -210,8 +196,8 @@ func svcPortScanWorker(wg *sync.WaitGroup) { if res == nil { continue } - for _, ans := range res.answers { - _, _, port, err := parseSRVAnswer(ans.String()) + for _, ans := range res.Answers { + _, _, port, err := dnsclient.ParseSRVAnswer(ans.String()) if err != nil { log.Warn().Err(err) continue @@ -227,20 +213,3 @@ func svcPortScanWorker(wg *sync.WaitGroup) { wg.Done() } - -func reverse(numbers []string) []string { - newNumbers := make([]string, len(numbers)) - for i, j := 0, len(numbers)-1; i <= j; i, j = i+1, j-1 { - newNumbers[i], newNumbers[j] = numbers[j], numbers[i] - } - return newNumbers -} - -func isElement(s []string, str string) bool { - for _, v := range s { - if v == str { - return true - } - } - return false -} diff --git a/cmd/mode.go b/pkg/scanners/mode.go similarity index 57% rename from cmd/mode.go rename to pkg/scanners/mode.go index 313aa1b..84503fb 100644 --- a/cmd/mode.go +++ b/pkg/scanners/mode.go @@ -1,4 +1,4 @@ -package cmd +package scanners import ( "crypto/tls" @@ -9,67 +9,11 @@ import ( "github.com/seancfoley/ipaddress-go/ipaddr" ) -const ( - MODE_AUTO = "auto" - MODE_BRUTEFORCE = "bruteforce" - MODE_WILDCARD = "wildcard" - MODE_FAILED = "failed" -) - -func detectMode() string { - if ok, _ := checkSpecVersion(); !ok { - log.Info().Msg("Unable to detect spec compliant Kubernetes DNS server") - return MODE_FAILED - } - - if ok, _ := wildcardK8sAddress(); ok { - log.Info().Msg("Wildcard support detected") - return MODE_WILDCARD - } - if ok, _ := queryDefaultK8sAddress(); ok { - log.Info().Msg("Falling back to bruteforce mode") - return MODE_BRUTEFORCE - } - log.Error().Msg("Failed to detect a CoreDNS server") - return MODE_FAILED -} - -func checkSpecVersion() (bool, error) { - res, err := queryTXT(fmt.Sprintf("dns-version.%s", opts.zone)) - if err != nil { - return false, err - } - - return res != nil, nil -} - -func queryDefaultK8sAddress() (bool, error) { - res, err := queryA(fmt.Sprintf("kubernetes.default.svc.%s", opts.zone)) - if err != nil { - return false, err - } - - return res != nil, nil -} - -func wildcardK8sAddress() (bool, error) { - res, err := queryA(fmt.Sprintf("any.any.svc.%s", opts.zone)) - if err != nil { - return false, err - } - - return res != nil, nil -} - -func getDefaultAPIServerCert() (*x509.Certificate, error) { - if ok, err := queryDefaultK8sAddress(); !ok || err != nil { - return nil, fmt.Errorf("couldnt query default apiserver") - } - +func GetDefaultAPIServerCert(zone string) (*x509.Certificate, error) { conf := &tls.Config{ InsecureSkipVerify: true, } - conn, err := tls.Dial("tcp", fmt.Sprintf("kubernetes.default.svc.%s:443", opts.zone), conf) + conn, err := tls.Dial("tcp", fmt.Sprintf("kubernetes.default.svc.%s:443", zone), conf) if err != nil { return nil, fmt.Errorf("Error in connecting to API server") } diff --git a/cmd/mode_test.go b/pkg/scanners/mode_test.go similarity index 99% rename from cmd/mode_test.go rename to pkg/scanners/mode_test.go index 8758755..3855223 100644 --- a/cmd/mode_test.go +++ b/pkg/scanners/mode_test.go @@ -1,4 +1,4 @@ -package cmd +package scanners import ( "crypto/rand" diff --git a/cmd/wildcard.go b/pkg/scanners/wildcard.go similarity index 56% rename from cmd/wildcard.go rename to pkg/scanners/wildcard.go index 85d5a72..6513b20 100644 --- a/cmd/wildcard.go +++ b/pkg/scanners/wildcard.go @@ -1,35 +1,37 @@ -package cmd +package scanners import ( "fmt" "strings" + "github.com/jpts/coredns-enum/internal/types" + "github.com/jpts/coredns-enum/pkg/dnsclient" "github.com/rs/zerolog/log" ) -// https://github.com/coredns/coredns.io/blob/1.8.4/content/plugins/kubernetes.md#wildcards +// https://github.com/coredns/corednsclient.io/blob/1.8.4/content/plugins/kubernetes.md#wildcards -func wildcard(opts *cliOpts) ([]*svcResult, error) { - var svcs []*svcResult +func WildcardScan(opts *types.CliOpts, dclient *dnsclient.DNSClient) ([]*types.SvcResult, error) { + var svcs []*types.SvcResult // port/proto - gives us namespaces for _, proto := range []string{"tcp", "udp"} { - res, err := querySRV(fmt.Sprintf("any._%s.any.any.svc.%s", proto, opts.zone)) + res, err := dclient.QuerySRV(fmt.Sprintf("any._%s.any.any.svc.%s", proto, opts.Zone)) if err != nil { return nil, err } - if res == nil || res.additional == nil { + if res == nil || res.Additional == nil { log.Debug().Msgf("No svcs for proto %s found", proto) continue } - for _, rr := range res.additional { - name, ns, ip, err := parseAAnswer(rr.String()) + for _, rr := range res.Additional { + name, ns, ip, err := dnsclient.ParseAAnswer(rr.String()) if err != nil { return nil, err } - svc := &svcResult{ + svc := &types.SvcResult{ Name: name, Namespace: ns, IP: &ip, @@ -37,12 +39,12 @@ func wildcard(opts *cliOpts) ([]*svcResult, error) { svcs, _ = addUniqueSvcToSvcs(svcs, svc) } - if res.answers == nil { + if res.Answers == nil { log.Debug().Msgf("No named ports for %s svcs found", proto) continue } - for _, rr := range res.answers { - name, ns, port, err := parseSRVAnswer(rr.String()) + for _, rr := range res.Answers { + name, ns, port, err := dnsclient.ParseSRVAnswer(rr.String()) if err != nil { return nil, err } @@ -52,23 +54,23 @@ func wildcard(opts *cliOpts) ([]*svcResult, error) { // endpoints for _, svc := range svcs { - res, err := queryA(fmt.Sprintf("any.%s.%s.svc.%s", svc.Name, svc.Namespace, opts.zone)) + res, err := dclient.QueryA(fmt.Sprintf("any.%s.%s.svc.%s", svc.Name, svc.Namespace, opts.Zone)) if err != nil { log.Warn().Err(err) continue } - if res == nil || res.answers == nil { + if res == nil || res.Answers == nil { log.Debug().Msgf("svc %s/%s has no registered endpoints", svc.Namespace, svc.Name) continue } - for _, rr := range res.answers { - _, _, ip, err := parseAAnswer(rr.String()) + for _, rr := range res.Answers { + _, _, ip, err := dnsclient.ParseAAnswer(rr.String()) if err != nil { log.Warn().Err(err) continue } - endp := &podResult{ + endp := &types.PodResult{ Name: svc.Name, Namespace: svc.Namespace, IP: &ip, @@ -80,7 +82,7 @@ func wildcard(opts *cliOpts) ([]*svcResult, error) { return svcs, nil } -func addUniqueSvcToSvcs(svcs []*svcResult, svc *svcResult) ([]*svcResult, error) { +func addUniqueSvcToSvcs(svcs []*types.SvcResult, svc *types.SvcResult) ([]*types.SvcResult, error) { for _, s := range svcs { if s.Name == svc.Name && s.Namespace == svc.Namespace && s.IP.String() == svc.IP.String() { log.Debug().Msgf("svc %s/%s already registered", svc.Namespace, svc.Name) @@ -91,7 +93,7 @@ func addUniqueSvcToSvcs(svcs []*svcResult, svc *svcResult) ([]*svcResult, error) return append(svcs, svc), nil } -func addPortToSvcs(svcs []*svcResult, podName string, ns string, proto string, port int, portName string) error { +func addPortToSvcs(svcs []*types.SvcResult, podName string, ns string, proto string, port int, portName string) error { for _, s := range svcs { if s.Name == podName && s.Namespace == ns { return addPortToSvc(s, proto, port, portName) @@ -101,8 +103,8 @@ func addPortToSvcs(svcs []*svcResult, podName string, ns string, proto string, p return nil } -func addPortToSvc(svc *svcResult, proto string, port int, portName string) error { - p := &portResult{ +func addPortToSvc(svc *types.SvcResult, proto string, port int, portName string) error { + p := &types.PortResult{ Proto: strings.TrimPrefix(proto, "_"), PortNo: port, PortName: strings.TrimPrefix(portName, "_"),