diff --git a/README.md b/README.md index ce7148baf..b531677a6 100644 --- a/README.md +++ b/README.md @@ -76,8 +76,9 @@ Application Options: --dns64-prefix= If specified, this is the DNS64 prefix dnsproxy will be using when it works as a DNS64 server. If not specified, dnsproxy uses the 'Well-Known Prefix' 64:ff9b:: --ipv6-disabled If specified, all AAAA requests will be replied with NoError RCode and empty answer - --bogus-nxdomain= Transform responses that contain at least one of the given IP addresses into NXDOMAIN. Can be - specified multiple times. + --bogus-nxdomain= Transform the responses containing at least a single IP + that matches specified addresses and CIDRs into + NXDOMAIN. Can be specified multiple times. --udp-buf-size= Set the size of the UDP buffer in bytes. A value <= 0 will use the system default. (default: 0) --max-go-routines= Set the maximum number of go routines. A value <= 0 will not not set a maximum. (default: 0) --version Prints the program version @@ -269,10 +270,20 @@ Now even if your IP address is 192.168.0.1 and it's not a public IP, the proxy w ### Bogus NXDomain -This option is similar to dnsmasq `bogus-nxdomain`. If specified, `dnsproxy` transforms responses that contain at least one of the given IP addresses into `NXDOMAIN`. Can be specified multiple times. +This option is similar to dnsmasq `bogus-nxdomain`. `dnsproxy` will transform +responses that contain at least a single IP address which is also specified by +the option into `NXDOMAIN`. Can be specified multiple times. -In the example below, we use AdGuard DNS server that returns `0.0.0.0` for blocked domains, and transform them to `NXDOMAIN`. +In the example below, we use AdGuard DNS server that returns `0.0.0.0` for +blocked domains, and transform them to `NXDOMAIN`. ``` ./dnsproxy -u 94.140.14.14:53 --bogus-nxdomain=0.0.0.0 ``` + +CIDR ranges are supported as well. The following will respond with `NXDOMAIN` +instead of responses containing any IP from `192.168.0.0`-`192.168.255.255`: + +``` +./dnsproxy -u 192.168.0.15:53 --bogus-nxdomain=192.168.0.0/16 +``` diff --git a/fastip/fastest.go b/fastip/fastest.go index 09c1f4f0a..685cbd321 100644 --- a/fastip/fastest.go +++ b/fastip/fastest.go @@ -67,7 +67,16 @@ func (f *FastestAddr) ExchangeFastest(req *dns.Msg, ups []upstream.Upstream) ( } host := strings.ToLower(req.Question[0].Name) - ips := f.extractIPs(replies) + + ips := make([]net.IP, 0, len(replies)) + for _, r := range replies { + for _, rr := range r.Resp.Answer { + ip := proxyutil.IPFromRR(rr) + if ip != nil && !containsIP(ips, ip) { + ips = append(ips, ip) + } + } + } if pingRes := f.pingAll(host, ips); pingRes != nil { return f.prepareReply(pingRes, replies) @@ -88,7 +97,7 @@ func (f *FastestAddr) prepareReply(pingRes *pingResult, replies []upstream.Excha ) { ip := pingRes.ipp.IP for _, r := range replies { - if hasAns(r.Resp, ip) { + if hasInAns(r.Resp, ip) { m = r.Resp u = r.Upstream @@ -128,10 +137,10 @@ func (f *FastestAddr) prepareReply(pingRes *pingResult, replies []upstream.Excha return m, u, nil } -// hasAns returns true if m contains ip in its answer section. -func hasAns(m *dns.Msg, ip net.IP) (ok bool) { +// hasInAns returns true if m contains ip in its Answer section. +func hasInAns(m *dns.Msg, ip net.IP) (ok bool) { for _, rr := range m.Answer { - respIP := proxyutil.GetIPFromDNSRecord(rr) + respIP := proxyutil.IPFromRR(rr) if respIP != nil && respIP.Equal(ip) { return true } @@ -140,20 +149,6 @@ func hasAns(m *dns.Msg, ip net.IP) (ok bool) { return false } -// extractIPs extracts all IP addresses from results. -func (f *FastestAddr) extractIPs(results []upstream.ExchangeAllResult) (ips []net.IP) { - for _, r := range results { - for _, rr := range r.Resp.Answer { - ip := proxyutil.GetIPFromDNSRecord(rr) - if ip != nil && !containsIP(ips, ip) { - ips = append(ips, ip) - } - } - } - - return ips -} - // containsIP returns true if ips contains the ip. func containsIP(ips []net.IP, ip net.IP) (ok bool) { if len(ips) == 0 { diff --git a/main.go b/main.go index acde4dc5e..ecc3bdd4c 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/netutil" "github.com/ameshkov/dnscrypt/v2" goFlags "github.com/jessevdk/go-flags" "gopkg.in/yaml.v3" @@ -150,7 +151,7 @@ type Options struct { IPv6Disabled bool `yaml:"ipv6-disabled" long:"ipv6-disabled" description:"If specified, all AAAA requests will be replied with NoError RCode and empty answer" optional:"yes" optional-value:"true"` // Transform responses that contain at least one of the given IP addresses into NXDOMAIN - BogusNXDomain []string `yaml:"bogus-nxdomain" long:"bogus-nxdomain" description:"Transform responses that contain at least one of the given IP addresses into NXDOMAIN. Can be specified multiple times."` + BogusNXDomain []string `yaml:"bogus-nxdomain" long:"bogus-nxdomain" description:"Transform the responses containing at least a single IP that matches specified addresses and CIDRs into NXDOMAIN. Can be specified multiple times."` // UDP buffer size value UDPBufferSize int `yaml:"udp-buf-size" long:"udp-buf-size" description:"Set the size of the UDP buffer in bytes. A value <= 0 will use the system default."` @@ -346,17 +347,19 @@ func initEDNS(config *proxy.Config, options *Options) { // initBogusNXDomain inits BogusNXDomain structure func initBogusNXDomain(config *proxy.Config, options *Options) { - if len(options.BogusNXDomain) > 0 { - bogusIP := []net.IP{} - for _, s := range options.BogusNXDomain { - ip := net.ParseIP(s) - if ip == nil { - log.Error("Invalid IP: %s", s) - } else { - bogusIP = append(bogusIP, ip) - } + if len(options.BogusNXDomain) == 0 { + return + } + + for _, s := range options.BogusNXDomain { + subnet, err := netutil.ParseSubnet(s) + if err != nil { + log.Error("%s", err) + + continue } - config.BogusNXDomain = bogusIP + + config.BogusNXDomain = append(config.BogusNXDomain, subnet) } } diff --git a/proxy/bogus_nxdomain.go b/proxy/bogus_nxdomain.go index c8fd818fd..9a540d722 100644 --- a/proxy/bogus_nxdomain.go +++ b/proxy/bogus_nxdomain.go @@ -5,24 +5,21 @@ import ( "github.com/miekg/dns" ) -// isBogusNXDomain - checks if the specified DNS message -// contains AT LEAST ONE ip address from the Proxy.BogusNXDomain list -func (p *Proxy) isBogusNXDomain(reply *dns.Msg) bool { - if reply == nil || - len(p.BogusNXDomain) == 0 || - len(reply.Answer) == 0 || - (reply.Question[0].Qtype != dns.TypeA && - reply.Question[0].Qtype != dns.TypeAAAA) { +// isBogusNXDomain returns true if m contains at least a single IP address in +// the Answer section contained in BogusNXDomain subnets of p. +func (p *Proxy) isBogusNXDomain(m *dns.Msg) (ok bool) { + if m == nil || len(p.BogusNXDomain) == 0 || len(m.Question) == 0 { + return false + } else if qt := m.Question[0].Qtype; qt != dns.TypeA && qt != dns.TypeAAAA { return false } - for _, rr := range reply.Answer { - ip := proxyutil.GetIPFromDNSRecord(rr) + for _, rr := range m.Answer { + ip := proxyutil.IPFromRR(rr) if proxyutil.ContainsIP(p.BogusNXDomain, ip) { return true } } - // No IPs are bogus if we got here return false } diff --git a/proxy/bogus_nxdomain_test.go b/proxy/bogus_nxdomain_test.go index 28917f2c4..98972c539 100644 --- a/proxy/bogus_nxdomain_test.go +++ b/proxy/bogus_nxdomain_test.go @@ -5,72 +5,99 @@ import ( "testing" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestBogusNXDomainTypeA(t *testing.T) { - dnsProxy := createTestProxy(t, nil) - dnsProxy.CacheEnabled = true - dnsProxy.BogusNXDomain = []net.IP{net.ParseIP("4.3.2.1")} +func TestProxy_IsBogusNXDomain(t *testing.T) { + prx := createTestProxy(t, nil) + prx.CacheEnabled = true - u := testUpstream{} - dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{&u} - err := dnsProxy.Start() - assert.Nil(t, err) + prx.BogusNXDomain = []*net.IPNet{{ + IP: net.IP{4, 3, 2, 1}, + Mask: net.CIDRMask(24, netutil.IPv4BitLen), + }, { + IP: net.IPv4(1, 2, 3, 4), + Mask: net.IPv4Mask(255, 0, 0, 0), + }, { + IP: net.IP{10, 11, 12, 13}, + Mask: net.CIDRMask(netutil.IPv4BitLen, netutil.IPv4BitLen), + }, { + IP: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + Mask: net.CIDRMask(120, netutil.IPv6BitLen), + }} - // first request - // upstream answers with a bogus IP - u.aResp = &dns.A{ - Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, - A: net.ParseIP("4.3.2.1"), - } + testCases := []struct { + name string + ans []dns.RR + wantRcode int + }{{ + name: "bogus_subnet", + ans: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, + A: net.ParseIP("4.3.2.1"), + }}, + wantRcode: dns.RcodeNameError, + }, { + name: "bogus_big_subnet", + ans: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, + A: net.ParseIP("1.254.254.254"), + }}, + wantRcode: dns.RcodeNameError, + }, { + name: "bogus_single_ip", + ans: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, + A: net.ParseIP("10.11.12.13"), + }}, + wantRcode: dns.RcodeNameError, + }, { + name: "bogus_6", + ans: []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10}, + AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 99}, + }}, + wantRcode: dns.RcodeNameError, + }, { + name: "non-bogus", + ans: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, + A: net.ParseIP("10.11.12.14"), + }}, + wantRcode: dns.RcodeSuccess, + }, { + name: "non-bogus_6", + ans: []dns.RR{&dns.AAAA{ + Hdr: dns.RR_Header{Rrtype: dns.TypeAAAA, Name: "host.", Ttl: 10}, + AAAA: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15}, + }}, + wantRcode: dns.RcodeSuccess, + }} - clientIP := net.IP{1, 2, 3, 0} - d := DNSContext{} - d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: clientIP, - } - - err = dnsProxy.Resolve(&d) - assert.Nil(t, err) + u := testUpstream{} + prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u} - // check response - assert.NotNil(t, d.Res) - assert.Equal(t, dns.RcodeNameError, d.Res.Rcode) + err := prx.Start() + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, prx.Stop) - // second request - // upstream answers with a normal IP - u.aResp = &dns.A{ - Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, - A: net.ParseIP("4.3.2.2"), + d := &DNSContext{ + Req: createHostTestMessage("host"), } - err = dnsProxy.Resolve(&d) - assert.Nil(t, err) + for _, tc := range testCases { + u.ans = tc.ans - // check response - assert.NotNil(t, d.Res) - assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) + t.Run(tc.name, func(t *testing.T) { + err = prx.Resolve(d) + require.NoError(t, err) + require.NotNil(t, d.Res) - // third request - // upstream answers with two IPs, one of them is bogus - u.aRespArr = append(u.aRespArr, &dns.A{ - Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, - A: net.ParseIP("4.3.2.2"), - }) - u.aRespArr = append(u.aRespArr, &dns.A{ - Hdr: dns.RR_Header{Rrtype: dns.TypeA, Name: "host.", Ttl: 10}, - A: net.ParseIP("4.3.2.1"), - }) - - err = dnsProxy.Resolve(&d) - assert.Nil(t, err) - - // check response - assert.NotNil(t, d.Res) - assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) - - _ = dnsProxy.Stop() + assert.Equal(t, tc.wantRcode, d.Res.Rcode) + }) + } } diff --git a/proxy/cache_test.go b/proxy/cache_test.go index 318d55542..a0793229b 100644 --- a/proxy/cache_test.go +++ b/proxy/cache_test.go @@ -364,14 +364,14 @@ func TestCacheExpirationWithTTLOverride(t *testing.T) { d.Req = createHostTestMessage("host") d.Addr = &net.TCPAddr{} - u.aResp = &dns.A{ + u.ans = []dns.RR{&dns.A{ Hdr: dns.RR_Header{ Rrtype: dns.TypeA, Name: "host.", Ttl: 10, }, A: net.IP{4, 3, 2, 1}, - } + }} err = dnsProxy.Resolve(d) require.NoError(t, err) @@ -388,14 +388,14 @@ func TestCacheExpirationWithTTLOverride(t *testing.T) { d.Req = createHostTestMessage("host2") d.Addr = &net.TCPAddr{} - u.aResp = &dns.A{ + u.ans = []dns.RR{&dns.A{ Hdr: dns.RR_Header{ Rrtype: dns.TypeA, Name: "host2.", Ttl: 60, }, A: net.IP{4, 3, 2, 1}, - } + }} err = dnsProxy.Resolve(d) assert.Nil(t, err) diff --git a/proxy/config.go b/proxy/config.go index bdaa5fb28..11185434c 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -83,7 +83,7 @@ type Config struct { // BogusNXDomain - transforms responses that contain at least one of the given IP addresses into NXDOMAIN // Similar to dnsmasq's "bogus-nxdomain" - BogusNXDomain []net.IP + BogusNXDomain []*net.IPNet // Enable EDNS Client Subnet option // DNS requests to the upstream server will contain an OPT record with Client Subnet option. diff --git a/proxy/proxy.go b/proxy/proxy.go index 6803c44b6..a68dd4e69 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -413,7 +413,7 @@ func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) { log.Tracef("Received an empty AAAA response, checking DNS64") reply, u, err = p.checkDNS64(req, reply, upstreams) } else if p.isBogusNXDomain(reply) { - log.Tracef("Received IP from the bogus-nxdomain list, replacing response") + log.Tracef("response ip is contained in bogus-nxdomain list") reply = p.genWithRCode(reply, dns.RcodeNameError) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 188061af3..f95bf4635 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -19,6 +19,7 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" glcache "github.com/AdguardTeam/golibs/cache" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/testutil" "github.com/ameshkov/dnscrypt/v2" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -821,30 +822,33 @@ func TestProxy_ReplyFromUpstream_badResponse(t *testing.T) { } func TestExchangeCustomUpstreamConfig(t *testing.T) { - dnsProxy := createTestProxy(t, nil) - err := dnsProxy.Start() - assert.True(t, err == nil) - - // this upstream will be used as a custom - u := testUpstream{} - u.aResp = new(dns.A) - u.aResp.Hdr.Rrtype = dns.TypeA - u.aResp.Hdr.Name = "host." - u.aResp.A = net.IP{4, 3, 2, 1} - u.aResp.Hdr.Ttl = 60 - config := &UpstreamConfig{Upstreams: []upstream.Upstream{&u}} - - // test request - d := DNSContext{} - d.CustomUpstreamConfig = config - d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: net.IP{1, 2, 3, 0}, + prx := createTestProxy(t, nil) + err := prx.Start() + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, prx.Stop) + + ansIP := net.IP{4, 3, 2, 1} + u := testUpstream{ + ans: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + Name: "host.", + Ttl: 60, + }, + A: ansIP, + }}, } - err = dnsProxy.Resolve(&d) - assert.Nil(t, err) - assert.Equal(t, u.aResp.A, getIPFromResponse(d.Res)) + d := DNSContext{ + CustomUpstreamConfig: &UpstreamConfig{Upstreams: []upstream.Upstream{&u}}, + Req: createHostTestMessage("host"), + Addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 0}}, + } + + err = prx.Resolve(&d) + require.NoError(t, err) + + assert.Equal(t, ansIP, getIPFromResponse(d.Res)) } func TestECS(t *testing.T) { @@ -859,41 +863,48 @@ func TestECS(t *testing.T) { // Resolve the same host with the different client subnet values func TestECSProxy(t *testing.T) { - dnsProxy := createTestProxy(t, nil) - dnsProxy.EnableEDNSClientSubnet = true - dnsProxy.CacheEnabled = true - u := testUpstream{} - dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{&u} - err := dnsProxy.Start() - assert.True(t, err == nil) + prx := createTestProxy(t, nil) + prx.EnableEDNSClientSubnet = true + prx.CacheEnabled = true + + u := testUpstream{ + ans: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + Name: "host.", + Ttl: 60, + }, + A: net.IP{4, 3, 2, 1}, + }}, + ecsIP: net.IP{1, 2, 3, 0}, + } + prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u} + err := prx.Start() + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, prx.Stop) // first request - d := DNSContext{} - d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: net.IP{1, 2, 3, 0}, - } - u.aResp = new(dns.A) - u.aResp.Hdr.Rrtype = dns.TypeA - u.aResp.Hdr.Name = "host." - u.aResp.A = net.IP{4, 3, 2, 1} - u.aResp.Hdr.Ttl = 60 - u.ecsIP = net.IP{1, 2, 3, 0} - err = dnsProxy.Resolve(&d) - assert.True(t, err == nil) + d := DNSContext{ + Req: createHostTestMessage("host"), + Addr: &net.TCPAddr{IP: net.IP{1, 2, 3, 0}}, + } + + err = prx.Resolve(&d) + require.NoError(t, err) + assert.True(t, getIPFromResponse(d.Res).Equal(net.IP{4, 3, 2, 1})) assert.True(t, u.ecsReqIP.Equal(net.IP{1, 2, 3, 0})) // request from another client with the same subnet - must be served from cache d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: net.IP{1, 2, 3, 1}, - } - u.aResp = nil + d.Addr = &net.TCPAddr{IP: net.IP{1, 2, 3, 1}} + u.ans = nil u.ecsIP = nil u.ecsReqIP = nil - err = dnsProxy.Resolve(&d) - assert.True(t, err == nil) + + err = prx.Resolve(&d) + require.NoError(t, err) + assert.True(t, getIPFromResponse(d.Res).Equal(net.IP{4, 3, 2, 1})) assert.True(t, u.ecsReqIP == nil) @@ -902,82 +913,93 @@ func TestECSProxy(t *testing.T) { d.Addr = &net.TCPAddr{ IP: net.IP{2, 2, 3, 0}, } - u.aResp = new(dns.A) - u.aResp.Hdr.Name = "host." - u.aResp.A = net.IP{4, 3, 2, 2} - u.aResp.Hdr.Ttl = 60 + u.ans = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + Name: "host.", + Ttl: 60, + }, + A: net.IP{4, 3, 2, 2}, + }} u.ecsIP = net.IP{2, 2, 3, 0} - u.ecsReqIP = nil - err = dnsProxy.Resolve(&d) - assert.True(t, err == nil) + + err = prx.Resolve(&d) + require.NoError(t, err) + assert.True(t, getIPFromResponse(d.Res).Equal(net.IP{4, 3, 2, 2})) assert.True(t, u.ecsReqIP.Equal(net.IP{2, 2, 3, 0})) // request from a local IP - store in general (not subnet-aware) cache d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: net.IP{127, 0, 0, 1}, - } - u.aResp = new(dns.A) - u.aResp.Hdr.Rrtype = dns.TypeA - u.aResp.Hdr.Name = "host." - u.aResp.A = net.IP{4, 3, 2, 3} - u.aResp.Hdr.Ttl = 60 + d.Addr = &net.TCPAddr{IP: net.IP{127, 0, 0, 1}} + u.ans = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + Name: "host.", + Ttl: 60, + }, + A: net.IP{4, 3, 2, 3}, + }} u.ecsIP = nil u.ecsReqIP = nil - err = dnsProxy.Resolve(&d) - assert.True(t, err == nil) + + err = prx.Resolve(&d) + require.NoError(t, err) + assert.True(t, getIPFromResponse(d.Res).Equal(net.IP{4, 3, 2, 3})) assert.True(t, u.ecsReqIP == nil) // request from another local IP - get from general cache d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: net.IP{127, 0, 0, 2}, - } - u.aResp = nil + d.Addr = &net.TCPAddr{IP: net.IP{127, 0, 0, 2}} + u.ans = nil u.ecsIP = nil u.ecsReqIP = nil - err = dnsProxy.Resolve(&d) - assert.True(t, err == nil) + + err = prx.Resolve(&d) + require.NoError(t, err) + assert.True(t, getIPFromResponse(d.Res).Equal(net.IP{4, 3, 2, 3})) assert.True(t, u.ecsReqIP == nil) - - _ = dnsProxy.Stop() } func TestECSProxyCacheMinMaxTTL(t *testing.T) { - dnsProxy := createTestProxy(t, nil) - dnsProxy.EnableEDNSClientSubnet = true - dnsProxy.CacheEnabled = true - dnsProxy.CacheMinTTL = 20 - dnsProxy.CacheMaxTTL = 40 - u := testUpstream{} - dnsProxy.UpstreamConfig.Upstreams = []upstream.Upstream{&u} - err := dnsProxy.Start() - assert.True(t, err == nil) + clientIP := net.IP{1, 2, 3, 0} + + prx := createTestProxy(t, nil) + prx.EnableEDNSClientSubnet = true + prx.CacheEnabled = true + prx.CacheMinTTL = 20 + prx.CacheMaxTTL = 40 + u := testUpstream{ + ans: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + Name: "host.", + Ttl: 10, + }, + A: net.IP{4, 3, 2, 1}, + }}, + ecsIP: clientIP, + } + prx.UpstreamConfig.Upstreams = []upstream.Upstream{&u} + err := prx.Start() + require.NoError(t, err) + testutil.CleanupAndRequireSuccess(t, prx.Stop) // first request - clientIP := net.IP{1, 2, 3, 0} - d := DNSContext{} - d.Req = createHostTestMessage("host") - d.Addr = &net.TCPAddr{ - IP: clientIP, + d := DNSContext{ + Req: createHostTestMessage("host"), + Addr: &net.TCPAddr{IP: clientIP}, } - u.aResp = new(dns.A) - u.aResp.Hdr.Rrtype = dns.TypeA - u.aResp.Hdr.Name = "host." - u.aResp.A = net.IP{4, 3, 2, 1} - u.aResp.Hdr.Ttl = 10 - u.ecsIP = clientIP - err = dnsProxy.Resolve(&d) + err = prx.Resolve(&d) require.NoError(t, err) // get from cache - check min TTL - ci, expired, key := dnsProxy.cache.getWithSubnet(d.Req, clientIP, 24) + ci, expired, key := prx.cache.getWithSubnet(d.Req, clientIP, 24) assert.False(t, expired) assert.Equal(t, key, msgToKeyWithSubnet(d.Req, clientIP, 24)) - assert.True(t, ci.m.Answer[0].Header().Ttl == dnsProxy.CacheMinTTL) + assert.True(t, ci.m.Answer[0].Header().Ttl == prx.CacheMinTTL) // 2nd request clientIP = net.IP{1, 2, 4, 0} @@ -985,22 +1007,23 @@ func TestECSProxyCacheMinMaxTTL(t *testing.T) { d.Addr = &net.TCPAddr{ IP: clientIP, } - u.aResp = new(dns.A) - u.aResp.Hdr.Rrtype = dns.TypeA - u.aResp.Hdr.Name = "host." - u.aResp.A = net.IP{4, 3, 2, 1} - u.aResp.Hdr.Ttl = 60 + u.ans = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{ + Rrtype: dns.TypeA, + Name: "host.", + Ttl: 60, + }, + A: net.IP{4, 3, 2, 1}, + }} u.ecsIP = clientIP - err = dnsProxy.Resolve(&d) - assert.True(t, err == nil) + err = prx.Resolve(&d) + require.NoError(t, err) // get from cache - check max TTL - ci, expired, key = dnsProxy.cache.getWithSubnet(d.Req, clientIP, 24) + ci, expired, key = prx.cache.getWithSubnet(d.Req, clientIP, 24) assert.False(t, expired) assert.Equal(t, key, msgToKeyWithSubnet(d.Req, clientIP, 24)) - assert.True(t, ci.m.Answer[0].Header().Ttl == dnsProxy.CacheMaxTTL) - - _ = dnsProxy.Stop() + assert.True(t, ci.m.Answer[0].Header().Ttl == prx.CacheMaxTTL) } func createTestDNSCryptProxy(t *testing.T) (*Proxy, dnscrypt.ResolverConfig) { @@ -1230,34 +1253,27 @@ func getIPFromResponse(resp *dns.Msg) net.IP { } type testUpstream struct { - cname1Resp *dns.CNAME - aResp *dns.A - aRespArr []*dns.A + ans []dns.RR + ecsIP net.IP ecsReqIP net.IP ecsReqMask uint8 } -func (u *testUpstream) Exchange(m *dns.Msg) (*dns.Msg, error) { - resp := dns.Msg{} +func (u *testUpstream) Exchange(m *dns.Msg) (resp *dns.Msg, err error) { + resp = &dns.Msg{} resp.SetReply(m) - if u.cname1Resp != nil { - resp.Answer = append(resp.Answer, u.cname1Resp) - } - - resp.Answer = append(resp.Answer, u.aResp) - - for _, a := range u.aRespArr { - resp.Answer = append(resp.Answer, a) + if u.ans != nil { + resp.Answer = append(resp.Answer, u.ans...) } u.ecsReqIP, u.ecsReqMask, _ = parseECS(m) if u.ecsIP != nil { - _, _ = setECS(&resp, u.ecsIP, 24) + _, _ = setECS(resp, u.ecsIP, 24) } - return &resp, nil + return resp, nil } func (u *testUpstream) Address() string { diff --git a/proxyutil/helpers.go b/proxyutil/helpers.go index 3b7ef99a2..abe43daa6 100644 --- a/proxyutil/helpers.go +++ b/proxyutil/helpers.go @@ -6,36 +6,32 @@ import ( "bytes" "net" - "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) -// IsConnClosed returns true if the error signals of a closed server connecting. -// -// Deprecated: This function is deprecated. Use errors.Is(err, net.ErrClosed) -// instead. -func IsConnClosed(err error) bool { - return errors.Is(err, net.ErrClosed) -} - -// GetIPFromDNSRecord - extracts IP address for a DNS record -// returns null if the record is of a wrong type -func GetIPFromDNSRecord(r dns.RR) net.IP { - switch addr := r.(type) { +// IPFromRR returns the IP address from rr if any. +func IPFromRR(rr dns.RR) (ip net.IP) { + switch rr := rr.(type) { case *dns.A: - return addr.A.To4() - + ip = rr.A.To4() case *dns.AAAA: - return addr.AAAA + ip = rr.AAAA + default: + // Go on. } - return nil + return ip } -// ContainsIP checks if the specified IP is in the array -func ContainsIP(ips []net.IP, ip net.IP) bool { - for _, i := range ips { - if i.Equal(ip) { +// ContainsIP returns true if any of nets contains ip. +func ContainsIP(nets []*net.IPNet, ip net.IP) (ok bool) { + if netutil.ValidateIP(ip) != nil { + return false + } + + for _, n := range nets { + if n.Contains(ip) { return true } } @@ -59,7 +55,8 @@ func AppendIPAddrs(ipAddrs *[]net.IPAddr, answers []dns.RR) { // SortIPAddrs sorts the specified IP addresses array // IPv4 addresses go first, then IPv6 addresses func SortIPAddrs(ipAddrs []net.IPAddr) []net.IPAddr { - if len(ipAddrs) < 2 { + l := len(ipAddrs) + if l <= 1 { return ipAddrs } @@ -85,13 +82,13 @@ func SortIPAddrs(ipAddrs []net.IPAddr) []net.IPAddr { return ipAddrs } -func compareIPAddrs(left, right net.IPAddr) int { - l4 := left.IP.To4() - r4 := right.IP.To4() +func compareIPAddrs(a, b net.IPAddr) int { + l4 := a.IP.To4() + r4 := b.IP.To4() if l4 != nil && r4 == nil { return -1 // IPv4 addresses first } else if l4 == nil && r4 != nil { return 1 // IPv4 addresses first } - return bytes.Compare(left.IP, right.IP) + return bytes.Compare(a.IP, b.IP) } diff --git a/proxyutil/helpers_test.go b/proxyutil/helpers_test.go index 788beee08..646f1e30d 100644 --- a/proxyutil/helpers_test.go +++ b/proxyutil/helpers_test.go @@ -23,19 +23,69 @@ func TestSortIPAddrs(t *testing.T) { } func TestContainsIP(t *testing.T) { - ips := []net.IP{} - ips = append(ips, net.ParseIP("94.140.14.15")) - ips = append(ips, net.ParseIP("2a10:50c0::bad1:ff")) + nets := []*net.IPNet{{ + // IPv4. + IP: net.IP{1, 2, 3, 0}, + Mask: net.IPv4Mask(255, 255, 255, 0), + }, { + // IPv6 from IPv4. + IP: net.IPv4(1, 2, 4, 0), + Mask: net.CIDRMask(16, 32), + }, { + // IPv6. + IP: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0}, + Mask: net.CIDRMask(120, net.IPv6len*8), + }} - ip := net.ParseIP("94.140.14.15") - assert.True(t, ContainsIP(ips, ip)) + testCases := []struct { + name string + want assert.BoolAssertionFunc + ip net.IP + }{{ + name: "ipv4_yes", + want: assert.True, + ip: net.IP{1, 2, 3, 255}, + }, { + name: "ipv4_6_yes", + want: assert.True, + ip: net.IPv4(1, 2, 4, 254), + }, { + name: "ipv6_yes", + want: assert.True, + ip: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, { + name: "ipv6_4_yes", + want: assert.True, + ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 1, 2, 3, 0}, + }, { + name: "ipv4_no", + want: assert.False, + ip: net.IP{2, 1, 3, 255}, + }, { + name: "ipv4_6_no", + want: assert.False, + ip: net.IPv4(2, 1, 4, 254), + }, { + name: "ipv6_no", + want: assert.False, + ip: net.IP{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 16, 15}, + }, { + name: "ipv6_4_no", + want: assert.False, + ip: net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 2, 1, 4, 0}, + }, { + name: "nil_no", + want: assert.False, + ip: nil, + }, { + name: "bad_ip", + want: assert.False, + ip: net.IP{42}, + }} - ip = net.ParseIP("2a10:50c0::bad1:ff") - assert.True(t, ContainsIP(ips, ip)) - - ip = net.ParseIP("2a10:50c0::bad1:ff1") - assert.False(t, ContainsIP(ips, ip)) - - ip = net.ParseIP("127.0.0.1") - assert.False(t, ContainsIP(ips, ip)) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tc.want(t, ContainsIP(nets, tc.ip)) + }) + } }