diff --git a/proxy/cache.go b/proxy/cache.go index d23131594..053c84948 100644 --- a/proxy/cache.go +++ b/proxy/cache.go @@ -152,16 +152,7 @@ func (p *Proxy) initCache() { c.initLazyWithSubnet() } - p.shortFlighter = newOptimisticResolver( - p.replyFromUpstream, - p.cacheResp, - c.del, - ) - p.shortFlighterWithSubnet = newOptimisticResolver( - p.replyFromUpstream, - p.cacheResp, - c.delWithSubnet, - ) + p.shortFlighter = newOptimisticResolver(p) } // get returns cached item for the req if it's found. expired is true if the @@ -528,25 +519,3 @@ func filterMsg(dst, m *dns.Msg, ad, do bool, ttl uint32) { dst.Ns = filterRRSlice(m.Ns, do, ttl, dns.TypeNone) dst.Extra = filterRRSlice(m.Extra, do, ttl, dns.TypeNone) } - -func (c *cache) del(key []byte) { - c.itemsLock.RLock() - defer c.itemsLock.RUnlock() - - if c.items == nil { - return - } - - c.items.Del(key) -} - -func (c *cache) delWithSubnet(key []byte) { - c.itemsWithSubnetLock.RLock() - defer c.itemsWithSubnetLock.RUnlock() - - if c.itemsWithSubnet == nil { - return - } - - c.itemsWithSubnet.Del(key) -} diff --git a/proxy/optimisticresolver.go b/proxy/optimisticresolver.go index 2fb013445..1d21d943d 100644 --- a/proxy/optimisticresolver.go +++ b/proxy/optimisticresolver.go @@ -7,35 +7,33 @@ import ( "github.com/AdguardTeam/golibs/log" ) -// resolveFunc is the signature of a method to resolve expired cached requests. -// This is exactly the signature of Proxy.replyFromUpstream. -type resolveFunc func(dctx *DNSContext) (ok bool, err error) +// cachingResolver is the DNS resolver that is also able to cache responses. +type cachingResolver interface { + // replyFromUpstream returns true if the request from dctx is successfully + // resolved and the response may be cached. + // + // TODO(e.burkov): Find out when ok can be false with nil err. + replyFromUpstream(dctx *DNSContext) (ok bool, err error) -// setFunc is the signature of a method to cache response. This is exactly the -// signature of Proxy.setInCache method. -type setFunc func(dctx *DNSContext) + // cacheResp caches the response from dctx. + cacheResp(dctx *DNSContext) +} -// deleteFunc is the signature of a method to remove the response from cache. -type deleteFunc func(key []byte) +// type check +var _ cachingResolver = (*Proxy)(nil) // optimisticResolver is used to eventually resolve expired cached requests. -// -// TODO(e.burkov): Think about generalizing all function-fields into a single -// interface. type optimisticResolver struct { - reqs *sync.Map - resolve resolveFunc - set setFunc - delete deleteFunc + reqs *sync.Map + cr cachingResolver } // newOptimisticResolver returns the new resolver for expired cached requests. -func newOptimisticResolver(rf resolveFunc, sf setFunc, df deleteFunc) (s *optimisticResolver) { +// cr must not be nil. +func newOptimisticResolver(cr cachingResolver) (s *optimisticResolver) { return &optimisticResolver{ - reqs: &sync.Map{}, - resolve: rf, - set: sf, - delete: df, + reqs: &sync.Map{}, + cr: cr, } } @@ -55,14 +53,12 @@ func (s *optimisticResolver) ResolveOnce(dctx *DNSContext, key []byte) { } defer s.reqs.Delete(keyHexed) - ok, err := s.resolve(dctx) + ok, err := s.cr.replyFromUpstream(dctx) if err != nil { log.Debug("resolving request for optimistic cache: %s", err) } if ok { - s.set(dctx) - } else { - s.delete(key) + s.cr.cacheResp(dctx) } } diff --git a/proxy/optimisticresolver_test.go b/proxy/optimisticresolver_test.go index 825f5fb60..44584497c 100644 --- a/proxy/optimisticresolver_test.go +++ b/proxy/optimisticresolver_test.go @@ -10,26 +10,45 @@ import ( "github.com/stretchr/testify/assert" ) -func TestOptimisticResolver_ResolveOnce(t *testing.T) { - in, out := make(chan unit), make(chan unit) - var timesResolved int - testResolveFunc := func(_ *DNSContext) (ok bool, err error) { - timesResolved++ +// testCachingResolver is a stub implementation of the cachingResolver interface +// to simplify testing. +type testCachingResolver struct { + onReplyFromUpstream func(dctx *DNSContext) (ok bool, err error) + onCacheResp func(dctx *DNSContext) +} - return true, nil - } +// replyFromUpstream implements the cachingResolver interface for +// *testCachingResolver. +func (tcr *testCachingResolver) replyFromUpstream(dctx *DNSContext) (ok bool, err error) { + return tcr.onReplyFromUpstream(dctx) +} - var timesSet int - testSetFunc := func(_ *DNSContext) { - timesSet++ +// cacheResp implements the cachingResolver interface for *testCachingResolver. +func (tcr *testCachingResolver) cacheResp(dctx *DNSContext) { + tcr.onCacheResp(dctx) +} - // Pass the signal to begin running secondary goroutines. - out <- unit{} - // Block until all the secondary goroutines finish. - <-in +func TestOptimisticResolver_ResolveOnce(t *testing.T) { + in, out := make(chan unit), make(chan unit) + var timesResolved, timesSet int + + tcr := &testCachingResolver{ + onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { + timesResolved++ + + return true, nil + }, + onCacheResp: func(_ *DNSContext) { + timesSet++ + + // Pass the signal to begin running secondary goroutines. + out <- unit{} + // Block until all the secondary goroutines finish. + <-in + }, } - s := newOptimisticResolver(testResolveFunc, testSetFunc, nil) + s := newOptimisticResolver(tcr) sameKey := []byte{1, 2, 3} // Start the primary goroutine. @@ -61,8 +80,6 @@ func TestOptimisticResolver_ResolveOnce(t *testing.T) { func TestOptimisticResolver_ResolveOnce_unsuccessful(t *testing.T) { key := []byte{1, 2, 3} - noopSetFunc := func(_ *DNSContext) {} - t.Run("error", func(t *testing.T) { logOutput := &bytes.Buffer{} @@ -76,29 +93,23 @@ func TestOptimisticResolver_ResolveOnce_unsuccessful(t *testing.T) { }) const rerr errors.Error = "sample resolving error" - testResolveFunc := func(_ *DNSContext) (ok bool, err error) { - return true, rerr - } - - s := newOptimisticResolver(testResolveFunc, noopSetFunc, nil) + s := newOptimisticResolver(&testCachingResolver{ + onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return true, rerr }, + onCacheResp: func(_ *DNSContext) {}, + }) s.ResolveOnce(nil, key) assert.Contains(t, logOutput.String(), rerr.Error()) }) t.Run("not_ok", func(t *testing.T) { - testResolveFunc := func(_ *DNSContext) (ok bool, err error) { - return false, nil - } - - var deleteCalled bool - testDeleteFunc := func(_ []byte) { - deleteCalled = true - } - - s := newOptimisticResolver(testResolveFunc, noopSetFunc, testDeleteFunc) + cached := false + s := newOptimisticResolver(&testCachingResolver{ + onReplyFromUpstream: func(_ *DNSContext) (ok bool, err error) { return false, nil }, + onCacheResp: func(_ *DNSContext) { cached = true }, + }) s.ResolveOnce(nil, key) - assert.True(t, deleteCalled) + assert.False(t, cached) }) } diff --git a/proxy/proxy.go b/proxy/proxy.go index a68dd4e69..93fd00d83 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -104,10 +104,6 @@ type Proxy struct { // shortFlighter is used to resolve the expired cached requests without // repetitions. shortFlighter *optimisticResolver - // shortFlighterWithSubnet is used to resolve the expired cached - // requests making sure that only one request for each cache item is - // performed at a time. - shortFlighterWithSubnet *optimisticResolver // FastestAddr module // -- @@ -388,8 +384,7 @@ func (p *Proxy) Addr(proto Proto) net.Addr { } } -// replyFromUpstream tries to resolve the request and caches it if cacheWorks is -// true. +// replyFromUpstream tries to resolve the request. func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) { req := d.Req host := req.Question[0].Name @@ -410,7 +405,7 @@ func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) { var u upstream.Upstream reply, u, err = p.exchange(req, upstreams) if p.isNAT64PrefixAvailable() && p.isEmptyAAAAResponse(reply, req) { - log.Tracef("Received an empty AAAA response, checking DNS64") + log.Tracef("received an empty AAAA response, checking DNS64") reply, u, err = p.checkDNS64(req, reply, upstreams) } else if p.isBogusNXDomain(reply) { log.Tracef("response ip is contained in bogus-nxdomain list") @@ -420,7 +415,7 @@ func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) { log.Tracef("RTT: %s", time.Since(start)) if err != nil && p.Fallbacks != nil { - log.Tracef("Using the fallback upstream due to %s", err) + log.Tracef("using the fallback upstream due to %s", err) reply, u, err = upstream.ExchangeParallel(p.Fallbacks, req) } @@ -432,9 +427,9 @@ func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) { d.Upstream = u p.setMinMaxTTL(reply) - // Explicitly construct the question section since some - // upstreams may respond with invalidly constructed messages - // which cause out-of-range panics afterwards. + // Explicitly construct the question section since some upstreams may + // respond with invalidly constructed messages which cause out-of-range + // panics afterwards. // // See https://github.com/AdguardTeam/AdGuardHome/issues/3551. if len(req.Question) > 0 && len(reply.Question) == 0 { diff --git a/proxy/proxy_cache.go b/proxy/proxy_cache.go index d1582b2c8..44023cea0 100644 --- a/proxy/proxy_cache.go +++ b/proxy/proxy_cache.go @@ -12,14 +12,13 @@ func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) { hitMsg := "serving cached response" var expired bool - var withSubnet bool var key []byte if !p.Config.EnableEDNSClientSubnet { ci, expired, key = p.cache.get(d.Req) - } else if withSubnet = d.ecsReqMask != 0; withSubnet { + } else if d.ecsReqMask != 0 { ci, expired, key = p.cache.getWithSubnet(d.Req, d.ecsReqIP, d.ecsReqMask) hitMsg = "serving response from subnet cache" - } else if d.ecsReqMask == 0 { + } else { ci, expired, key = p.cache.get(d.Req) hitMsg = "serving response from general cache" } @@ -49,11 +48,7 @@ func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) { minCtxClone.Req = req } - if !withSubnet { - go p.shortFlighter.ResolveOnce(minCtxClone, key) - } else { - go p.shortFlighterWithSubnet.ResolveOnce(minCtxClone, key) - } + go p.shortFlighter.ResolveOnce(minCtxClone, key) } return hit diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index f95bf4635..23fed092e 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -1324,96 +1324,69 @@ func TestProxy_Resolve_withOptimisticResolver(t *testing.T) { }, } - testFunc := func(t *testing.T, responsed bool) (firstCtx *DNSContext) { - p.initCache() - out, in := make(chan unit), make(chan unit) - p.shortFlighter.resolve = func(dctx *DNSContext) (ok bool, err error) { + p.initCache() + out, in := make(chan unit), make(chan unit) + p.shortFlighter.cr = &testCachingResolver{ + onReplyFromUpstream: func(dctx *DNSContext) (ok bool, err error) { dctx.Res = buildResp(dctx.Req, nonOptimisticTTL) - return responsed, nil - } - p.shortFlighter.set = func(dctx *DNSContext) { - defer func() { - out <- unit{} - }() - + return true, nil + }, + onCacheResp: func(dctx *DNSContext) { // Report adding to cache is in process. out <- unit{} // Wait for tests to finish. <-in p.cacheResp(dctx) - } - p.shortFlighter.delete = func(k []byte) { - defer func() { - out <- unit{} - }() - // Report deleting from cache is in process. + // Report adding tocache is finished. out <- unit{} - // Wait for tests to finish. - <-in - - p.cache.del(k) - } - - // Two different contexts are made to emulate two different - // requests with the same question section. - var secondCtx *DNSContext - firstCtx, secondCtx = buildCtx(), buildCtx() - - // Add expired response into cache. - req := firstCtx.Req - key := msgToKey(req) - data := (&cacheItem{ - m: buildResp(req, 0), - u: testUpsAddr, - }).pack() - items := glcache.New(glcache.Config{ - EnableLRU: true, - }) - items.Set(key, data) - p.cache.items = items - - err := p.Resolve(firstCtx) - require.NoError(t, err) - require.Len(t, firstCtx.Res.Answer, 1) - - assert.EqualValues(t, optimisticTTL, firstCtx.Res.Answer[0].Header().Ttl) + }, + } - // Wait for optimisticResolver to reach the tested function. - <-out + // Two different contexts are made to emulate two different requests + // with the same question section. + firstCtx, secondCtx := buildCtx(), buildCtx() - err = p.Resolve(secondCtx) - require.NoError(t, err) - require.Len(t, secondCtx.Res.Answer, 1) + // Add expired response into cache. + req := firstCtx.Req + key := msgToKey(req) + data := (&cacheItem{ + m: buildResp(req, 0), + u: testUpsAddr, + }).pack() + items := glcache.New(glcache.Config{ + EnableLRU: true, + }) + items.Set(key, data) + p.cache.items = items - assert.EqualValues(t, optimisticTTL, secondCtx.Res.Answer[0].Header().Ttl) + err := p.Resolve(firstCtx) + require.NoError(t, err) + require.Len(t, firstCtx.Res.Answer, 1) - // Continue and wait for it to finish. - in <- unit{} - <-out + assert.EqualValues(t, optimisticTTL, firstCtx.Res.Answer[0].Header().Ttl) - return firstCtx - } + // Wait for optimisticResolver to reach the tested function. + <-out - t.Run("successful", func(t *testing.T) { - firstCtx := testFunc(t, true) + err = p.Resolve(secondCtx) + require.NoError(t, err) + require.Len(t, secondCtx.Res.Answer, 1) - // Should be served from cache. - data := p.cache.items.Get(msgToKey(firstCtx.Req)) - unpacked, expired := p.cache.unpackItem(data, firstCtx.Req) - require.False(t, expired) - require.NotNil(t, unpacked) - require.Len(t, unpacked.m.Answer, 1) + assert.EqualValues(t, optimisticTTL, secondCtx.Res.Answer[0].Header().Ttl) - assert.EqualValues(t, nonOptimisticTTL, unpacked.m.Answer[0].Header().Ttl) - }) + // Continue and wait for it to finish. + in <- unit{} + <-out - t.Run("unsuccessful", func(t *testing.T) { - firstCtx := testFunc(t, false) + // Should be served from cache. + data = p.cache.items.Get(msgToKey(firstCtx.Req)) + unpacked, expired := p.cache.unpackItem(data, firstCtx.Req) + require.False(t, expired) + require.NotNil(t, unpacked) + require.Len(t, unpacked.m.Answer, 1) - // Should be removed from cache. - assert.Nil(t, p.cache.items.Get(msgToKey(firstCtx.Req))) - }) + assert.EqualValues(t, nonOptimisticTTL, unpacked.m.Answer[0].Header().Ttl) }