diff --git a/main.go b/main.go index 3a1e9ac..2169485 100644 --- a/main.go +++ b/main.go @@ -71,8 +71,6 @@ func main() { os.Exit(1) } - question := getQuestion() - if timeoutStr != "" { i, err := strconv.Atoi(timeoutStr) if err != nil { @@ -85,18 +83,35 @@ func main() { } var server string - if len(os.Args) > 2 { - server = os.Args[2] - } else { + var domain string + if len(os.Args) == 2 { sysr, err := sysresolv.NewSystemResolvers(nil, 53) if err != nil { log.Printf("Cannot get system resolvers: %v", err) os.Exit(1) } - server = sysr.Addrs()[0].String() + domain = os.Args[1] + } else { + if os.Args[1][0] == '@' { + if os.Args[2][0] == '@' { + log.Printf("There are two dns servers") + usage() + os.Exit(1) + } + server = os.Args[1][1:] + domain = os.Args[2] + } else if os.Args[2][0] == '@' { + server = os.Args[2][1:] + domain = os.Args[1] + } else { + domain = os.Args[1] + server = os.Args[2] + } } + question := getQuestion(domain) + var httpVersions []upstream.HTTPVersion if http3Enabled { httpVersions = []upstream.HTTPVersion{ @@ -236,8 +251,7 @@ func getEDNSOpt() (option *dns.EDNS0_LOCAL) { } // getQuestion returns a DNS question for the query. -func getQuestion() (q dns.Question) { - domain := os.Args[1] +func getQuestion(domain string) (q dns.Question) { rrType := getRRType() qClass := getClass() @@ -354,6 +368,8 @@ func getRRType() (rrType uint16) { func usage() { _, _ = os.Stdout.WriteString("Usage: dnslookup [ ]\n") + _, _ = os.Stdout.WriteString(" or: dnslookup @ [ ]\n") + _, _ = os.Stdout.WriteString(" or: dnslookup @ [ ]\n") _, _ = os.Stdout.WriteString(": mandatory, domain name to lookup\n") _, _ = os.Stdout.WriteString(": mandatory, server address. Supported: plain, tcp:// (TCP), tls:// (DOT), https:// (DOH), sdns:// (DNSCrypt), quic:// (DOQ)\n") _, _ = os.Stdout.WriteString(": optional, DNSCrypt provider name\n")