Skip to content

Commit

Permalink
ref: Refactor doh client and segregate config options.
Browse files Browse the repository at this point in the history
  • Loading branch information
iamd3vil committed Aug 2, 2021
1 parent 75cd2af commit d4694ce
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 103 deletions.
13 changes: 13 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package main

import (
"time"

"github.com/miekg/dns"
)

// DNSInCache is serialized and stored in the cache
type DNSInCache struct {
Msg *dns.Msg
CreatedAt time.Time
}
7 changes: 7 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package main

import "github.com/miekg/dns"

type DNSClient interface {
GetDNSResponse(m *dns.Msg) (*dns.Msg, error)
}
17 changes: 13 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,22 @@ type cfgResolver struct {
Urls []string `koanf:"urls"`
}

type cfgCache struct {
Cache bool `koanf:"cache"`
MaxItems int `koanf:"max_items"`
}

type cfgLog struct {
LogLevel string `koanf:"log_level"`
LogQueries bool `koanf:"log_queries"`
}

type Config struct {
BindAddress string `koanf:"bind_address"`
Cache bool `koanf:"cache"`
LogLevel string `koanf:"log_level"`
LogQueries bool `koanf:"log_queries"`
Resolver cfgResolver `koanf:"resolver"`
BootstrapAddress string `koanf:"bootstrap_address"`
Cache cfgCache `koanf:"cache"`
Resolver cfgResolver `koanf:"resolver"`
Log cfgLog `koanf:"log"`
}

func initConfig(cfgPath string) (Config, error) {
Expand Down
23 changes: 15 additions & 8 deletions config.sample.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
# Address for danse to listen on.
bind_address = "127.0.0.1:5454"

# Only used for resolving resolver url. No queries received by danse will be sent here. Default is 9.9.9.9:53
bootstrap_address = "1.1.1.1:53"

# Urls for resolvers.
[resolver]
urls = ["https://dns.quad9.net/dns-query", "https://cloudflare-dns.com/dns-query"]


[cache]
# Should the answers be cached according to ttl. Default is true.
cache = true

# Maximum records to cache.
max_items = 10000

# Config for logging
[log]
# Log level
log_level = "info"

# Logs all queries to stdout. Default is false.
log_queries = true

# Only used for resolving resolver url. No queries received by danse will be sent here. Default is 9.9.9.9:53
bootstrap_address = "1.1.1.1:53"

# Urls for resolvers.
[resolver]
urls = ["https://dns.quad9.net/dns-query", "https://cloudflare-dns.com/dns-query"]
log_queries = true
74 changes: 0 additions & 74 deletions dns.go

This file was deleted.

77 changes: 77 additions & 0 deletions doh_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package main

import (
"bytes"
"errors"
"io/ioutil"
"log"
"net/http"
"sync"

"github.com/miekg/dns"
)

type DohClient struct {
httpClient *http.Client

// Slice of URLs
urls []string

// Last used index
lIndex int

logQueries bool

sync.Mutex
}

func (c *DohClient) GetDNSResponse(msg *dns.Msg) (*dns.Msg, error) {
b, err := msg.Pack()
if err != nil {
return &dns.Msg{}, err
}

c.Lock()

url := c.urls[c.lIndex]

// Increase last index
c.lIndex++

if c.lIndex == len(c.urls) {
c.lIndex = 0
}

c.Unlock()

if c.logQueries {
log.Printf("Sending to %s for query: %s", url, msg.Question[0].String())
}

resp, err := c.httpClient.Post(url, "application/dns-message", bytes.NewBuffer(b))
if err != nil {
return &dns.Msg{}, err
}
if resp.StatusCode != http.StatusOK {
log.Printf("Response from DOH provider has status code: %d", resp.StatusCode)
return &dns.Msg{}, errors.New("error from DOH provider")
}

body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return &dns.Msg{}, nil
}

r := &dns.Msg{}
err = r.Unpack(body)

return r, err
}

func NewDOHClient(c *http.Client, urls []string, logQueries bool) (*DohClient, error) {
return &DohClient{
httpClient: c,
urls: urls,
logQueries: logQueries,
}, nil
}
48 changes: 31 additions & 17 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,48 @@ package main
import (
"context"
"flag"
"fmt"
"log"
"math"
"net"
"net/http"
"os"
"time"

lru "github.com/hashicorp/golang-lru"
"github.com/miekg/dns"
)

var buildString string

type env struct {
httpClient *http.Client
dnsUrls *dnsURLs
cache *lru.Cache
cfg Config
cache *lru.Cache
cfg Config
client DNSClient
}

func main() {
cfgPath := flag.String("config", "config.toml", "Path to config file")
version := flag.Bool("version", false, "Version")
flag.Parse()

if *version {
fmt.Println(buildString)
os.Exit(0)
}

cfg, err := initConfig(*cfgPath)
if err != nil {
log.Fatalf("error reading config: %v", err)
}

maxCacheItems := 512
if cfg.Cache.MaxItems != 0 {
maxCacheItems = cfg.Cache.MaxItems
}

// Initialize cache
cache, err := lru.New(512)
cache, err := lru.New(maxCacheItems)
if err != nil {
log.Fatalln("Couldn't create cache: ", err)
}
Expand Down Expand Up @@ -60,17 +74,17 @@ func main() {
Transport: transport,
}

dnsServer := &dns.Server{Addr: cfg.BindAddress, Net: "udp"}

dnsurls := &dnsURLs{
urls: cfg.Resolver.Urls,
dnsClient, err := NewDOHClient(httpClient, cfg.Resolver.Urls, cfg.Log.LogQueries)
if err != nil {
log.Fatalf("error creating doh client: %v", err)
}

dnsServer := &dns.Server{Addr: cfg.BindAddress, Net: "udp"}

e := env{
httpClient: httpClient,
dnsUrls: dnsurls,
cache: cache,
cfg: cfg,
cache: cache,
cfg: cfg,
client: dnsClient,
}

dns.HandleFunc(".", e.handleDNS)
Expand All @@ -83,13 +97,13 @@ func (e *env) handleDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}

if e.cfg.LogQueries {
if e.cfg.Log.LogQueries {
log.Println("Got DNS query for ", r.Question[0].String())
}

cacheKey := r.Question[0].String()
// Check if we should check cache or not
if !e.cfg.Cache {
if !e.cfg.Cache.Cache {
e.getAndSendResponse(w, r, cacheKey)
return
}
Expand Down Expand Up @@ -119,14 +133,14 @@ func (e *env) handleDNS(w dns.ResponseWriter, r *dns.Msg) {
}

func (e *env) getAndSendResponse(w dns.ResponseWriter, r *dns.Msg, cacheKey string) {
respMsg, err := e.GetDNSResponse(r, e.httpClient, e.dnsUrls)
respMsg, err := e.client.GetDNSResponse(r)
if err != nil {
log.Printf("Something wrong with resp: %v", err)
return
}

// Put it in cache
if e.cfg.Cache {
if e.cfg.Cache.Cache {
dnsCache := DNSInCache{Msg: respMsg, CreatedAt: time.Now()}
e.cache.Add(cacheKey, dnsCache)
}
Expand Down

0 comments on commit d4694ce

Please sign in to comment.