From a411c5feee6585ef4ee5f3e016a247b10942d3ec Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Fri, 12 Apr 2024 19:08:02 +0000 Subject: [PATCH 01/24] reloadCacheEntry removes edns and publishes to redis --- resolver/caching_resolver.go | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index e2a052f01..1a4fc39dc 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -111,21 +111,33 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) req := newRequest(dns.Fqdn(domainName), qType) + response, err := r.next.Resolve(ctx, req) + if err != nil { + util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err) - if err == nil { - if response.Res.Rcode == dns.RcodeSuccess { - packed, err := response.Res.Pack() - if err != nil { - logger.Error("unable to pack response", err) + return nil, 0 + } - return nil, 0 - } + if response.Res.Rcode == dns.RcodeSuccess { + respCopy := response.Res.Copy() - return &packed, r.adjustTTLs(response.Res.Answer) + // don't cache any EDNS OPT records + util.RemoveEdns0Record(respCopy) + + packed, err := respCopy.Pack() + if err != nil { + logger.Error("unable to pack response", err) + + return nil, 0 } - } else { - util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err) + + if r.redisClient != nil { + res := *respCopy + r.redisClient.PublishCache(cacheKey, &res) + } + + return &packed, r.adjustTTLs(response.Res.Answer) } return nil, 0 From 6a899075b665c79541c937b95a32aba8472e92d6 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Fri, 12 Apr 2024 23:43:13 +0200 Subject: [PATCH 02/24] Update resolver/caching_resolver.go Co-authored-by: ThinkChaos --- resolver/caching_resolver.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 1a4fc39dc..757f03b9f 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -114,7 +114,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) response, err := r.next.Resolve(ctx, req) if err != nil { - util.LogOnError(ctx, fmt.Sprintf("can't prefetch '%s' ", domainName), err) + logger.WithError(err).WithField("domain", domainName).Warn("cache prefetch failed") return nil, 0 } From 86071445e5ad28f9cf8bf60466bb77867d5e0bd6 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Sat, 13 Apr 2024 00:09:29 +0000 Subject: [PATCH 03/24] added transformAndPublish --- resolver/caching_resolver.go | 85 +++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 41 deletions(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 757f03b9f..0cab64b0d 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -2,7 +2,6 @@ package resolver import ( "context" - "fmt" "math" "sync/atomic" "time" @@ -119,28 +118,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) return nil, 0 } - if response.Res.Rcode == dns.RcodeSuccess { - respCopy := response.Res.Copy() - - // don't cache any EDNS OPT records - util.RemoveEdns0Record(respCopy) - - packed, err := respCopy.Pack() - if err != nil { - logger.Error("unable to pack response", err) - - return nil, 0 - } - - if r.redisClient != nil { - res := *respCopy - r.redisClient.PublishCache(cacheKey, &res) - } - - return &packed, r.adjustTTLs(response.Res.Answer) - } - - return nil, 0 + return r.transformAndPublish(ctx, cacheKey, response, true) } func (r *CachingResolver) redisSubscriber(ctx context.Context) { @@ -151,8 +129,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { case rc := <-r.redisClient.CacheChannel: if rc != nil { logger.Debug("Received key from redis: ", rc.Key) - ttl := r.adjustTTLs(rc.Response.Res.Answer) - r.putInCache(ctx, rc.Key, rc.Response, ttl, false) + r.putInCache(ctx, rc.Key, rc.Response, false) } case <-ctx.Done(): @@ -205,8 +182,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( response, err = r.next.Resolve(ctx, request) if err == nil { - cacheTTL := r.adjustTTLs(response.Res.Answer) - r.putInCache(ctx, cacheKey, response, cacheTTL, true) + r.putInCache(ctx, cacheKey, response, true) } } @@ -262,33 +238,60 @@ func isResponseCacheable(msg *dns.Msg) bool { return !msg.Truncated && !msg.CheckingDisabled } -func (r *CachingResolver) putInCache( - ctx context.Context, cacheKey string, response *model.Response, ttl time.Duration, publish bool, -) { +// transformAndPublish transforms the response to a byte array and publishes it to redis if publish is true +// and redis is enabled. Returns the byte array and the TTL of the response +func (r *CachingResolver) transformAndPublish(ctx context.Context, cacheKey string, + response *model.Response, publish bool, +) (*[]byte, time.Duration) { + if response.Res.Rcode == dns.RcodeSuccess && !isResponseCacheable(response.Res) { + return nil, 0 + } + + _, domainName := util.ExtractCacheKey(cacheKey) + + _, logger := r.log(ctx) + respCopy := response.Res.Copy() // don't cache any EDNS OPT records util.RemoveEdns0Record(respCopy) packed, err := respCopy.Pack() - util.LogOnError(ctx, "error on packing", err) - - if err == nil { - if response.Res.Rcode == dns.RcodeSuccess && isResponseCacheable(response.Res) { - // put value into cache - r.resultCache.Put(cacheKey, &packed, ttl) - } else if response.Res.Rcode == dns.RcodeNameError { - if r.cfg.CacheTimeNegative.IsAboveZero() { - // put negative cache if result code is NXDOMAIN - r.resultCache.Put(cacheKey, &packed, r.cfg.CacheTimeNegative.ToDuration()) - } + if err != nil { + logger.WithError(err).WithField("domain", domainName).Warn("cache prefetch failed") + + return nil, 0 + } + + ttl := time.Duration(0) + + if response.Res.Rcode == dns.RcodeSuccess { + ttl = r.adjustTTLs(response.Res.Answer) + } else if response.Res.Rcode == dns.RcodeNameError { + if r.cfg.CacheTimeNegative.IsAboveZero() { + ttl = r.cfg.CacheTimeNegative.ToDuration() } } if publish && r.redisClient != nil { res := *respCopy + for _, rr := range res.Answer { + rr.Header().Ttl = uint32(ttl.Seconds()) + } + r.redisClient.PublishCache(cacheKey, &res) } + + return &packed, ttl +} + +func (r *CachingResolver) putInCache( + ctx context.Context, cacheKey string, response *model.Response, publish bool, +) { + res, ttl := r.transformAndPublish(ctx, cacheKey, response, publish) + if res != nil { + r.resultCache.Put(cacheKey, res, ttl) + } } // adjustTTLs calculates and returns the min TTL (considers also the min and max cache time) From 43e59a48e3ac64ba73852112ffcec714e181b909 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Sun, 14 Apr 2024 13:31:36 +0000 Subject: [PATCH 04/24] wip --- resolver/caching_resolver.go | 206 ++++++++++++++++------------------- util/dns.go | 68 ++++++++++++ 2 files changed, 163 insertions(+), 111 deletions(-) create mode 100644 util/dns.go diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 0cab64b0d..c6172e60c 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -2,8 +2,6 @@ package resolver import ( "context" - "math" - "sync/atomic" "time" "github.com/0xERR0R/blocky/cache/expirationcache" @@ -17,7 +15,11 @@ import ( "github.com/sirupsen/logrus" ) -const defaultCachingCleanUpInterval = 5 * time.Second +const ( + defaultCachingCleanUpInterval = 5 * time.Second + // noCacheTTL indicates that a response should not be cached + noCacheTTL = time.Duration(-1) +) // CachingResolver caches answers from dns queries with their TTL time, // to avoid external resolver calls for recurrent queries @@ -106,19 +108,36 @@ func configureCaches(ctx context.Context, c *CachingResolver, cfg *config.Cachin func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) (*[]byte, time.Duration) { qType, domainName := util.ExtractCacheKey(cacheKey) ctx, logger := r.log(ctx) + logger = logger.WithField("domain", util.Obfuscate(domainName)) - logger.Debugf("prefetching '%s' (%s)", util.Obfuscate(domainName), qType) + logger.Debugf("prefetching %s", qType) req := newRequest(dns.Fqdn(domainName), qType) response, err := r.next.Resolve(ctx, req) if err != nil { - logger.WithError(err).WithField("domain", domainName).Warn("cache prefetch failed") + logger.WithError(err).Warn("cache prefetch failed") + + return nil, 0 + } + cacheCopy, ttl := r.createCacheEntry(logger, response.Res) + if cacheCopy == nil || !cacheableTTL(ttl) { return nil, 0 } - return r.transformAndPublish(ctx, cacheKey, response, true) + packed, err := cacheCopy.Pack() + if err != nil { + logger.WithError(err).WithError(err).Warn("response packing failed") + + return nil, 0 + } + + if r.redisClient != nil { + r.redisClient.PublishCache(cacheKey, cacheCopy) + } + + return &packed, ttl } func (r *CachingResolver) redisSubscriber(ctx context.Context) { @@ -128,8 +147,13 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { select { case rc := <-r.redisClient.CacheChannel: if rc != nil { - logger.Debug("Received key from redis: ", rc.Key) - r.putInCache(ctx, rc.Key, rc.Response, false) + _, domain := util.ExtractCacheKey(rc.Key) + + dlogger := logger.WithField("domain", util.Obfuscate(domain)) + + dlogger.Debug("received from redis") + + r.putInCache(dlogger, rc.Key, rc.Response) } case <-ctx.Done(): @@ -161,62 +185,56 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( cacheKey := util.GenerateCacheKey(dns.Type(question.Qtype), domain) logger := logger.WithField("domain", util.Obfuscate(domain)) - val, ttl := r.getFromCache(logger, cacheKey) + cacheEntry := r.getFromCache(logger, cacheKey) - if val != nil { + if cacheEntry != nil { logger.Debug("domain is cached") - val.SetRcode(request.Req, val.Rcode) + cacheEntry.SetRcode(request.Req, cacheEntry.Rcode) - // Adjust TTL - setTTLInCachedResponse(val, ttl) - - if val.Rcode == dns.RcodeSuccess { - return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil + if cacheEntry.Rcode == dns.RcodeSuccess { + return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED"}, nil } - return &model.Response{Res: val, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil + return &model.Response{Res: cacheEntry, RType: model.ResponseTypeCACHED, Reason: "CACHED NEGATIVE"}, nil } logger.WithField("next_resolver", Name(r.next)).Trace("not in cache: go to next resolver") - response, err = r.next.Resolve(ctx, request) + response, err = r.next.Resolve(ctx, request) if err == nil { - r.putInCache(ctx, cacheKey, response, true) + ttl := r.modifyResponseTTL(response.Res) + if cacheableTTL(ttl) { + cacheCopy := r.putInCache(logger, cacheKey, response) + if cacheCopy != nil && r.redisClient != nil { + r.redisClient.PublishCache(cacheKey, cacheCopy) + } + } } } return response, err } -func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) (*dns.Msg, time.Duration) { - val, ttl := r.resultCache.Get(key) - if val == nil { - return nil, 0 +func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) *dns.Msg { + raw, ttl := r.resultCache.Get(key) + if raw == nil { + return nil } res := new(dns.Msg) - err := res.Unpack(*val) + err := res.Unpack(*raw) if err != nil { logger.Error("can't unpack cached entry. Cache malformed?", err) - return nil, 0 + return nil } - return res, ttl -} - -func setTTLInCachedResponse(resp *dns.Msg, ttl time.Duration) { - minTTL := uint32(math.MaxInt32) - // find smallest TTL first - for _, rr := range resp.Answer { - minTTL = min(minTTL, rr.Header().Ttl) - } + // Adjust TTL + util.AdjustAnswerTTL(res, uint32(ttl.Seconds())) - for _, rr := range resp.Answer { - rr.Header().Ttl = rr.Header().Ttl - minTTL + uint32(ttl.Seconds()) - } + return res } // isRequestCacheable returns true if the request should be cached @@ -232,99 +250,61 @@ func isRequestCacheable(request *model.Request) bool { return true } -// isResponseCacheable returns true if the response is not truncated and its CD flag isn't set. -func isResponseCacheable(msg *dns.Msg) bool { - // we don't cache truncated responses and responses with CD flag - return !msg.Truncated && !msg.CheckingDisabled -} - -// transformAndPublish transforms the response to a byte array and publishes it to redis if publish is true -// and redis is enabled. Returns the byte array and the TTL of the response -func (r *CachingResolver) transformAndPublish(ctx context.Context, cacheKey string, - response *model.Response, publish bool, -) (*[]byte, time.Duration) { - if response.Res.Rcode == dns.RcodeSuccess && !isResponseCacheable(response.Res) { - return nil, 0 +func (r *CachingResolver) putInCache(logger *logrus.Entry, cacheKey string, response *model.Response) *dns.Msg { + cacheCopy, ttl := r.createCacheEntry(logger, response.Res) + if cacheCopy == nil || !cacheableTTL(ttl) { + return nil } - _, domainName := util.ExtractCacheKey(cacheKey) - - _, logger := r.log(ctx) - - respCopy := response.Res.Copy() - - // don't cache any EDNS OPT records - util.RemoveEdns0Record(respCopy) - - packed, err := respCopy.Pack() + packed, err := cacheCopy.Pack() if err != nil { - logger.WithError(err).WithField("domain", domainName).Warn("cache prefetch failed") + logger.WithError(err).Warn("response packing failed") - return nil, 0 + return nil } - ttl := time.Duration(0) - - if response.Res.Rcode == dns.RcodeSuccess { - ttl = r.adjustTTLs(response.Res.Answer) - } else if response.Res.Rcode == dns.RcodeNameError { - if r.cfg.CacheTimeNegative.IsAboveZero() { - ttl = r.cfg.CacheTimeNegative.ToDuration() - } - } + r.resultCache.Put(cacheKey, &packed, ttl) - if publish && r.redisClient != nil { - res := *respCopy - for _, rr := range res.Answer { - rr.Header().Ttl = uint32(ttl.Seconds()) - } + return cacheCopy +} - r.redisClient.PublishCache(cacheKey, &res) +func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) time.Duration { + // if response is empty or negative, return negative cache time from config + if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError { + return r.cfg.CacheTimeNegative.ToDuration() } - return &packed, ttl -} + // if response is truncated or CD flag is set, return noCacheTTL since we don't cache these responses + if response.Truncated || response.CheckingDisabled { + return noCacheTTL + } -func (r *CachingResolver) putInCache( - ctx context.Context, cacheKey string, response *model.Response, publish bool, -) { - res, ttl := r.transformAndPublish(ctx, cacheKey, response, publish) - if res != nil { - r.resultCache.Put(cacheKey, res, ttl) + // if response is not successful, return noCacheTTL since we don't cache these responses + if response.Rcode != dns.RcodeSuccess { + return noCacheTTL } -} -// adjustTTLs calculates and returns the min TTL (considers also the min and max cache time) -// for all records from answer or a negative cache time for empty answer -// adjust the TTL in the answer header accordingly -func (r *CachingResolver) adjustTTLs(answer []dns.RR) (ttl time.Duration) { - minTTL := uint32(math.MaxInt32) + // adjust TTLs of all answers to match the configured min and max caching times + util.SetAnswerMinMaxTTL(response, r.cfg.MinCachingTime.SecondsU32(), r.cfg.MaxCachingTime.SecondsU32()) - if len(answer) == 0 { - return r.cfg.CacheTimeNegative.ToDuration() - } + return time.Duration(util.GetAnswerMinTTL(response)) * time.Second +} - for _, a := range answer { - // if TTL < mitTTL -> adjust the value, set minTTL - if r.cfg.MinCachingTime.IsAboveZero() { - if atomic.LoadUint32(&a.Header().Ttl) < r.cfg.MinCachingTime.SecondsU32() { - atomic.StoreUint32(&a.Header().Ttl, r.cfg.MinCachingTime.SecondsU32()) - } - } +func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg, +) (*dns.Msg, time.Duration) { + response := input.Copy() - if r.cfg.MaxCachingTime.IsAboveZero() { - if atomic.LoadUint32(&a.Header().Ttl) > r.cfg.MaxCachingTime.SecondsU32() { - atomic.StoreUint32(&a.Header().Ttl, r.cfg.MaxCachingTime.SecondsU32()) - } - } + ttl := r.modifyResponseTTL(response) + if !cacheableTTL(ttl) { + logger.Debug("response is not cacheable") - headerTTL := atomic.LoadUint32(&a.Header().Ttl) - if minTTL > headerTTL { - minTTL = headerTTL - } + return nil, noCacheTTL } - return time.Duration(minTTL) * time.Second + // don't cache any EDNS OPT records + util.RemoveEdns0Record(response) + + return response, ttl } func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) { @@ -339,3 +319,7 @@ func (r *CachingResolver) FlushCaches(ctx context.Context) { logger.Debug("flush caches") r.resultCache.Clear() } + +func cacheableTTL(ttl time.Duration) bool { + return ttl > 0 +} diff --git a/util/dns.go b/util/dns.go new file mode 100644 index 000000000..9f5917747 --- /dev/null +++ b/util/dns.go @@ -0,0 +1,68 @@ +package util + +import ( + "math" + "sync/atomic" + + "github.com/miekg/dns" +) + +// SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to +// the minimum TTL. +func SetAnswerMinTTL(msg *dns.Msg, minTTL uint32) { + for _, answer := range msg.Answer { + if atomic.LoadUint32(&answer.Header().Ttl) < minTTL { + atomic.StoreUint32(&answer.Header().Ttl, minTTL) + } + } +} + +// SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL +// to the maximum TTL. +func SetAnswerMaxTTL(msg *dns.Msg, maxTTL uint32) { + for _, answer := range msg.Answer { + if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL && maxTTL != 0 { + atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + } + } +} + +// SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL +// to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL. +func SetAnswerMinMaxTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32) { + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + if headerTTL < minTTL { + atomic.StoreUint32(&answer.Header().Ttl, minTTL) + } else if headerTTL > maxTTL && maxTTL != 0 { + atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + } + } +} + +// GetMinAnswerTTL returns the lowest TTL of all answers in the message. +func GetAnswerMinTTL(msg *dns.Msg) uint32 { + var minTTL atomic.Uint32 + // initialize minTTL with the maximum value of uint32 + minTTL.Store(math.MaxUint32) + + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + if headerTTL < minTTL.Load() { + minTTL.Store(headerTTL) + } + } + + return minTTL.Load() +} + +// AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL +// and the answer's TTL plus the specified adjustment. +func AdjustAnswerTTL(msg *dns.Msg, adjustment uint32) { + minTTL := GetAnswerMinTTL(msg) + + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustment) + } +} From 5eb351f23c0bccd6288f410baf284f43f1d8759c Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Tue, 16 Apr 2024 18:09:49 +0000 Subject: [PATCH 05/24] dns util rework --- util/dns.go | 77 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/util/dns.go b/util/dns.go index 9f5917747..552b3e73f 100644 --- a/util/dns.go +++ b/util/dns.go @@ -2,14 +2,76 @@ package util import ( "math" + "strconv" "sync/atomic" + "time" "github.com/miekg/dns" ) +// ttlInput is the input type for TTL values and consists of the following types: +// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, string, time.Duration +type ttlInput interface { + int | int8 | int16 | int32 | int64 | uint | uint8 | uint32 | uint64 | string | time.Duration +} + +// ToTTL converts the input to a TTL of seconds as uint32. +func ToTTL[T ttlInput](input T) uint32 { + // use int64 as the intermediate type + res := int64(0) + + switch typedInput := any(input).(type) { + case string: + if seconds, err := strconv.Atoi(typedInput); err == nil { + res = int64(seconds) + } else { + if duration, err := time.ParseDuration(typedInput); err == nil { + res = int64(duration.Seconds()) + } + } + case time.Duration: + res = int64(typedInput.Seconds()) + case int: + res = int64(typedInput) + case int8: + res = int64(typedInput) + case int16: + res = int64(typedInput) + case int32: + res = int64(typedInput) + case int64: + res = typedInput + case uint: + res = int64(typedInput) + case uint8: + res = int64(typedInput) + case uint16: + res = int64(typedInput) + case uint32: + res = int64(typedInput) + case uint64: + res = int64(typedInput) + default: + panic("invalid TTL value input type") + } + + // check if the value is negative or greater than the maximum value of uint32 + if res < 0 { + // there is no negative TTL + return 0 + } else if res > math.MaxUint32 { + // since TTL is a 32-bit unsigned integer, the maximum value is math.MaxUint32 + return math.MaxUint32 + } + + // return the value as uint32 + return uint32(res) +} + // SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to // the minimum TTL. -func SetAnswerMinTTL(msg *dns.Msg, minTTL uint32) { +func SetAnswerMinTTL[T ttlInput](msg *dns.Msg, min T) { + minTTL := ToTTL(min) for _, answer := range msg.Answer { if atomic.LoadUint32(&answer.Header().Ttl) < minTTL { atomic.StoreUint32(&answer.Header().Ttl, minTTL) @@ -19,7 +81,8 @@ func SetAnswerMinTTL(msg *dns.Msg, minTTL uint32) { // SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL // to the maximum TTL. -func SetAnswerMaxTTL(msg *dns.Msg, maxTTL uint32) { +func SetAnswerMaxTTL[T ttlInput](msg *dns.Msg, max T) { + maxTTL := ToTTL(max) for _, answer := range msg.Answer { if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL && maxTTL != 0 { atomic.StoreUint32(&answer.Header().Ttl, maxTTL) @@ -29,7 +92,10 @@ func SetAnswerMaxTTL(msg *dns.Msg, maxTTL uint32) { // SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL // to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL. -func SetAnswerMinMaxTTL(msg *dns.Msg, minTTL uint32, maxTTL uint32) { +func SetAnswerMinMaxTTL[T ttlInput](msg *dns.Msg, min, max T) { + minTTL := ToTTL(min) + maxTTL := ToTTL(max) + for _, answer := range msg.Answer { headerTTL := atomic.LoadUint32(&answer.Header().Ttl) if headerTTL < minTTL { @@ -58,11 +124,12 @@ func GetAnswerMinTTL(msg *dns.Msg) uint32 { // AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL // and the answer's TTL plus the specified adjustment. -func AdjustAnswerTTL(msg *dns.Msg, adjustment uint32) { +func AdjustAnswerTTL[T ttlInput](msg *dns.Msg, adjustment T) { minTTL := GetAnswerMinTTL(msg) + adjustmentTTL := ToTTL(adjustment) for _, answer := range msg.Answer { headerTTL := atomic.LoadUint32(&answer.Header().Ttl) - atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustment) + atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL) } } From 0f6c84752947ded6fea3ff8750478d829236d6ef Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 17:12:00 +0000 Subject: [PATCH 06/24] fixed generics --- util/dns.go | 63 +++++++++++++++-------------------------------------- 1 file changed, 17 insertions(+), 46 deletions(-) diff --git a/util/dns.go b/util/dns.go index 552b3e73f..4fb45fee9 100644 --- a/util/dns.go +++ b/util/dns.go @@ -2,57 +2,28 @@ package util import ( "math" - "strconv" "sync/atomic" - "time" "github.com/miekg/dns" ) -// ttlInput is the input type for TTL values and consists of the following types: -// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, string, time.Duration -type ttlInput interface { - int | int8 | int16 | int32 | int64 | uint | uint8 | uint32 | uint64 | string | time.Duration +// TTLInput is the input type for TTL values and consists of the following underlying types: +// int, uint, uint32, int64 +type TTLInput interface { + ~int | ~uint | ~uint32 | ~int64 } // ToTTL converts the input to a TTL of seconds as uint32. -func ToTTL[T ttlInput](input T) uint32 { +// If the input is of underlying type time.Duration, the value is converted to seconds. +// If the input is negative, the TTL is set to 0. +// If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32. +func ToTTL[T TTLInput](input T) uint32 { // use int64 as the intermediate type - res := int64(0) - - switch typedInput := any(input).(type) { - case string: - if seconds, err := strconv.Atoi(typedInput); err == nil { - res = int64(seconds) - } else { - if duration, err := time.ParseDuration(typedInput); err == nil { - res = int64(duration.Seconds()) - } - } - case time.Duration: - res = int64(typedInput.Seconds()) - case int: - res = int64(typedInput) - case int8: - res = int64(typedInput) - case int16: - res = int64(typedInput) - case int32: - res = int64(typedInput) - case int64: - res = typedInput - case uint: - res = int64(typedInput) - case uint8: - res = int64(typedInput) - case uint16: - res = int64(typedInput) - case uint32: - res = int64(typedInput) - case uint64: - res = int64(typedInput) - default: - panic("invalid TTL value input type") + res := int64(input) + + // check if the input is of underlying type time.Duration + if durType, ok := any(input).(interface{ Seconds() float64 }); ok { + res = int64(durType.Seconds()) } // check if the value is negative or greater than the maximum value of uint32 @@ -70,7 +41,7 @@ func ToTTL[T ttlInput](input T) uint32 { // SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to // the minimum TTL. -func SetAnswerMinTTL[T ttlInput](msg *dns.Msg, min T) { +func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) { minTTL := ToTTL(min) for _, answer := range msg.Answer { if atomic.LoadUint32(&answer.Header().Ttl) < minTTL { @@ -81,7 +52,7 @@ func SetAnswerMinTTL[T ttlInput](msg *dns.Msg, min T) { // SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL // to the maximum TTL. -func SetAnswerMaxTTL[T ttlInput](msg *dns.Msg, max T) { +func SetAnswerMaxTTL[T TTLInput](msg *dns.Msg, max T) { maxTTL := ToTTL(max) for _, answer := range msg.Answer { if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL && maxTTL != 0 { @@ -92,7 +63,7 @@ func SetAnswerMaxTTL[T ttlInput](msg *dns.Msg, max T) { // SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL // to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL. -func SetAnswerMinMaxTTL[T ttlInput](msg *dns.Msg, min, max T) { +func SetAnswerMinMaxTTL[T TTLInput, TT TTLInput](msg *dns.Msg, min T, max TT) { minTTL := ToTTL(min) maxTTL := ToTTL(max) @@ -124,7 +95,7 @@ func GetAnswerMinTTL(msg *dns.Msg) uint32 { // AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL // and the answer's TTL plus the specified adjustment. -func AdjustAnswerTTL[T ttlInput](msg *dns.Msg, adjustment T) { +func AdjustAnswerTTL[T TTLInput](msg *dns.Msg, adjustment T) { minTTL := GetAnswerMinTTL(msg) adjustmentTTL := ToTTL(adjustment) From 27c60e74da970d43b0d949df3051db13bca9bcd8 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 17:31:47 +0000 Subject: [PATCH 07/24] linebreaks in doctag --- util/dns.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/util/dns.go b/util/dns.go index 4fb45fee9..e97dd63f9 100644 --- a/util/dns.go +++ b/util/dns.go @@ -14,8 +14,11 @@ type TTLInput interface { } // ToTTL converts the input to a TTL of seconds as uint32. +// // If the input is of underlying type time.Duration, the value is converted to seconds. +// // If the input is negative, the TTL is set to 0. +// // If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32. func ToTTL[T TTLInput](input T) uint32 { // use int64 as the intermediate type From 7f39983ffefa4c4af8a830ca9823538441c44444 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 17:33:16 +0000 Subject: [PATCH 08/24] refactoring --- resolver/caching_resolver.go | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index c6172e60c..f06ab9ad4 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -18,7 +18,7 @@ import ( const ( defaultCachingCleanUpInterval = 5 * time.Second // noCacheTTL indicates that a response should not be cached - noCacheTTL = time.Duration(-1) + noCacheTTL = uint32(0) ) // CachingResolver caches answers from dns queries with their TTL time, @@ -122,7 +122,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) } cacheCopy, ttl := r.createCacheEntry(logger, response.Res) - if cacheCopy == nil || !cacheableTTL(ttl) { + if cacheCopy == nil || ttl == noCacheTTL { return nil, 0 } @@ -137,7 +137,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) r.redisClient.PublishCache(cacheKey, cacheCopy) } - return &packed, ttl + return &packed, time.Duration(ttl) * time.Second } func (r *CachingResolver) redisSubscriber(ctx context.Context) { @@ -204,7 +204,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( response, err = r.next.Resolve(ctx, request) if err == nil { ttl := r.modifyResponseTTL(response.Res) - if cacheableTTL(ttl) { + if ttl > noCacheTTL { cacheCopy := r.putInCache(logger, cacheKey, response) if cacheCopy != nil && r.redisClient != nil { r.redisClient.PublishCache(cacheKey, cacheCopy) @@ -232,7 +232,7 @@ func (r *CachingResolver) getFromCache(logger *logrus.Entry, key string) *dns.Ms } // Adjust TTL - util.AdjustAnswerTTL(res, uint32(ttl.Seconds())) + util.AdjustAnswerTTL(res, ttl) return res } @@ -252,7 +252,7 @@ func isRequestCacheable(request *model.Request) bool { func (r *CachingResolver) putInCache(logger *logrus.Entry, cacheKey string, response *model.Response) *dns.Msg { cacheCopy, ttl := r.createCacheEntry(logger, response.Res) - if cacheCopy == nil || !cacheableTTL(ttl) { + if cacheCopy == nil || ttl == noCacheTTL { return nil } @@ -263,42 +263,42 @@ func (r *CachingResolver) putInCache(logger *logrus.Entry, cacheKey string, resp return nil } - r.resultCache.Put(cacheKey, &packed, ttl) + r.resultCache.Put(cacheKey, &packed, time.Duration(ttl)*time.Second) return cacheCopy } -func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) time.Duration { +func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 { // if response is empty or negative, return negative cache time from config if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError { - return r.cfg.CacheTimeNegative.ToDuration() + return util.ToTTL(r.cfg.CacheTimeNegative) } // if response is truncated or CD flag is set, return noCacheTTL since we don't cache these responses if response.Truncated || response.CheckingDisabled { - return noCacheTTL + return 0 } // if response is not successful, return noCacheTTL since we don't cache these responses if response.Rcode != dns.RcodeSuccess { - return noCacheTTL + return 0 } // adjust TTLs of all answers to match the configured min and max caching times - util.SetAnswerMinMaxTTL(response, r.cfg.MinCachingTime.SecondsU32(), r.cfg.MaxCachingTime.SecondsU32()) + util.SetAnswerMinMaxTTL(response, r.cfg.MinCachingTime, r.cfg.MaxCachingTime) - return time.Duration(util.GetAnswerMinTTL(response)) * time.Second + return util.GetAnswerMinTTL(response) } func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg, -) (*dns.Msg, time.Duration) { +) (*dns.Msg, uint32) { response := input.Copy() ttl := r.modifyResponseTTL(response) - if !cacheableTTL(ttl) { + if ttl == noCacheTTL { logger.Debug("response is not cacheable") - return nil, noCacheTTL + return nil, 0 } // don't cache any EDNS OPT records @@ -319,7 +319,3 @@ func (r *CachingResolver) FlushCaches(ctx context.Context) { logger.Debug("flush caches") r.resultCache.Clear() } - -func cacheableTTL(ttl time.Duration) bool { - return ttl > 0 -} From 66a7155650f4eac2c6dd7ab56c60569653baa39d Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 17:49:08 +0000 Subject: [PATCH 09/24] added ToTTLDuration --- resolver/caching_resolver.go | 4 ++-- util/dns.go | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index f06ab9ad4..4e4b88228 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -137,7 +137,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) r.redisClient.PublishCache(cacheKey, cacheCopy) } - return &packed, time.Duration(ttl) * time.Second + return &packed, util.ToTTLDuration(ttl) } func (r *CachingResolver) redisSubscriber(ctx context.Context) { @@ -263,7 +263,7 @@ func (r *CachingResolver) putInCache(logger *logrus.Entry, cacheKey string, resp return nil } - r.resultCache.Put(cacheKey, &packed, time.Duration(ttl)*time.Second) + r.resultCache.Put(cacheKey, &packed, util.ToTTLDuration(ttl)) return cacheCopy } diff --git a/util/dns.go b/util/dns.go index e97dd63f9..9af1d09b1 100644 --- a/util/dns.go +++ b/util/dns.go @@ -3,6 +3,7 @@ package util import ( "math" "sync/atomic" + "time" "github.com/miekg/dns" ) @@ -42,6 +43,15 @@ func ToTTL[T TTLInput](input T) uint32 { return uint32(res) } +// ToTTLDuration converts the input to a time.Duration. +// +// If the input is of underlying type time.Duration, the value is returned as is. +// +// Otherwise the value is converted to seconds and returned as time.Duration. +func ToTTLDuration[T TTLInput](input T) time.Duration { + return time.Duration(ToTTL(input)) * time.Second +} + // SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to // the minimum TTL. func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) { From 6de43c84a69c4a3b6eed6c854905ccf434167bd5 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 18:13:28 +0000 Subject: [PATCH 10/24] added fast returns to reduce loops --- util/dns.go | 49 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/util/dns.go b/util/dns.go index 9af1d09b1..13119ca18 100644 --- a/util/dns.go +++ b/util/dns.go @@ -22,7 +22,12 @@ type TTLInput interface { // // If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32. func ToTTL[T TTLInput](input T) uint32 { - // use int64 as the intermediate type + // fast return if the input is already of type uint32 + if ui32Type, ok := any(input).(uint32); ok { + return ui32Type + } + + // use int64 as the intermediate type for conversion res := int64(input) // check if the input is of underlying type time.Duration @@ -55,10 +60,11 @@ func ToTTLDuration[T TTLInput](input T) time.Duration { // SetAnswerMinTTL sets the TTL of all answers in the message that are less than the specified minimum TTL to // the minimum TTL. func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) { - minTTL := ToTTL(min) - for _, answer := range msg.Answer { - if atomic.LoadUint32(&answer.Header().Ttl) < minTTL { - atomic.StoreUint32(&answer.Header().Ttl, minTTL) + if minTTL := ToTTL(min); minTTL != 0 { + for _, answer := range msg.Answer { + if atomic.LoadUint32(&answer.Header().Ttl) < minTTL { + atomic.StoreUint32(&answer.Header().Ttl, minTTL) + } } } } @@ -66,10 +72,11 @@ func SetAnswerMinTTL[T TTLInput](msg *dns.Msg, min T) { // SetAnswerMaxTTL sets the TTL of all answers in the message that are greater than the specified maximum TTL // to the maximum TTL. func SetAnswerMaxTTL[T TTLInput](msg *dns.Msg, max T) { - maxTTL := ToTTL(max) - for _, answer := range msg.Answer { - if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL && maxTTL != 0 { - atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + if maxTTL := ToTTL(max); maxTTL != 0 { + for _, answer := range msg.Answer { + if atomic.LoadUint32(&answer.Header().Ttl) > maxTTL { + atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + } } } } @@ -80,12 +87,24 @@ func SetAnswerMinMaxTTL[T TTLInput, TT TTLInput](msg *dns.Msg, min T, max TT) { minTTL := ToTTL(min) maxTTL := ToTTL(max) - for _, answer := range msg.Answer { - headerTTL := atomic.LoadUint32(&answer.Header().Ttl) - if headerTTL < minTTL { - atomic.StoreUint32(&answer.Header().Ttl, minTTL) - } else if headerTTL > maxTTL && maxTTL != 0 { - atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + switch { + case minTTL == 0 && maxTTL == 0: + // no TTL specified, fast return + return + case minTTL != 0 && maxTTL == 0: + // only minimum TTL specified + SetAnswerMinTTL(msg, min) + case minTTL == 0 && maxTTL != 0: + // only maximum TTL specified + SetAnswerMaxTTL(msg, max) + default: + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + if headerTTL < minTTL { + atomic.StoreUint32(&answer.Header().Ttl, minTTL) + } else if headerTTL > maxTTL && maxTTL != 0 { + atomic.StoreUint32(&answer.Header().Ttl, maxTTL) + } } } } From 46441bafa13558d2159a21d3f0a94245dee9b393 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 18:22:40 +0000 Subject: [PATCH 11/24] fast return and comments --- util/dns.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/util/dns.go b/util/dns.go index 13119ca18..9686929d8 100644 --- a/util/dns.go +++ b/util/dns.go @@ -50,9 +50,7 @@ func ToTTL[T TTLInput](input T) uint32 { // ToTTLDuration converts the input to a time.Duration. // -// If the input is of underlying type time.Duration, the value is returned as is. -// -// Otherwise the value is converted to seconds and returned as time.Duration. +// The input is converted to a TTL of seconds as uint32 and then to a time.Duration. func ToTTLDuration[T TTLInput](input T) time.Duration { return time.Duration(ToTTL(input)) * time.Second } @@ -98,11 +96,12 @@ func SetAnswerMinMaxTTL[T TTLInput, TT TTLInput](msg *dns.Msg, min T, max TT) { // only maximum TTL specified SetAnswerMaxTTL(msg, max) default: + // both minimum and maximum TTL specified for _, answer := range msg.Answer { headerTTL := atomic.LoadUint32(&answer.Header().Ttl) if headerTTL < minTTL { atomic.StoreUint32(&answer.Header().Ttl, minTTL) - } else if headerTTL > maxTTL && maxTTL != 0 { + } else if headerTTL > maxTTL { atomic.StoreUint32(&answer.Header().Ttl, maxTTL) } } @@ -127,12 +126,15 @@ func GetAnswerMinTTL(msg *dns.Msg) uint32 { // AdjustAnswerTTL adjusts the TTL of all answers in the message by the difference between the lowest TTL // and the answer's TTL plus the specified adjustment. +// +// If the adjustment is zero, the TTL is not changed. func AdjustAnswerTTL[T TTLInput](msg *dns.Msg, adjustment T) { - minTTL := GetAnswerMinTTL(msg) - adjustmentTTL := ToTTL(adjustment) + if adjustmentTTL := ToTTL(adjustment); adjustmentTTL != 0 { + minTTL := GetAnswerMinTTL(msg) - for _, answer := range msg.Answer { - headerTTL := atomic.LoadUint32(&answer.Header().Ttl) - atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL) + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL) + } } } From f28e58c689a37940ded1aef6c6537c9017b2e972 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 18:57:10 +0000 Subject: [PATCH 12/24] added dns unit tests --- util/dns_test.go | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 util/dns_test.go diff --git a/util/dns_test.go b/util/dns_test.go new file mode 100644 index 000000000..ef6874880 --- /dev/null +++ b/util/dns_test.go @@ -0,0 +1,37 @@ +package util + +import ( + "math" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("EDNS0 utils", func() { + DescribeTable("ToTTL", + func(input interface{}, expected int) { + res := uint32(0) + switch it := input.(type) { + case uint32: + res = ToTTL(it) + case int: + res = ToTTL(it) + case int64: + res = ToTTL(it) + case time.Duration: + res = ToTTL(it) + default: + Fail("unsupported type") + } + + Expect(ToTTL(res)).Should(Equal(uint32(expected))) + }, + Entry("should return 0 for negative input", -1, 0), + Entry("should return uint32 for uint32 input", uint32(1), 1), + Entry("should return uint32 for int input", 1, 1), + Entry("should return uint32 for int64 input", int64(1), 1), + Entry("should return seconds for time.Duration input", time.Second, 1), + Entry("should return math.MaxUint32 for too large input", int64(math.MaxUint32)+1, math.MaxUint32), + ) +}) From 5dc3a16a7ce74180466c0f13a426af5cdee68dab Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Wed, 17 Apr 2024 19:10:47 +0000 Subject: [PATCH 13/24] no fast return since adjustment is needed to set minttl --- util/dns.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/util/dns.go b/util/dns.go index 9686929d8..25b7b4cba 100644 --- a/util/dns.go +++ b/util/dns.go @@ -129,12 +129,11 @@ func GetAnswerMinTTL(msg *dns.Msg) uint32 { // // If the adjustment is zero, the TTL is not changed. func AdjustAnswerTTL[T TTLInput](msg *dns.Msg, adjustment T) { - if adjustmentTTL := ToTTL(adjustment); adjustmentTTL != 0 { - minTTL := GetAnswerMinTTL(msg) + adjustmentTTL := ToTTL(adjustment) + minTTL := GetAnswerMinTTL(msg) - for _, answer := range msg.Answer { - headerTTL := atomic.LoadUint32(&answer.Header().Ttl) - atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL) - } + for _, answer := range msg.Answer { + headerTTL := atomic.LoadUint32(&answer.Header().Ttl) + atomic.StoreUint32(&answer.Header().Ttl, headerTTL-minTTL+adjustmentTTL) } } From 37fbb423b72f019cabe3322f71efaecf9a0323ec Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Sun, 21 Apr 2024 15:38:08 +0000 Subject: [PATCH 14/24] use correct context --- cache/expirationcache/expiration_cache.go | 6 +++--- cache/expirationcache/expiration_cache_test.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cache/expirationcache/expiration_cache.go b/cache/expirationcache/expiration_cache.go index ab37e3133..400308c68 100644 --- a/cache/expirationcache/expiration_cache.go +++ b/cache/expirationcache/expiration_cache.go @@ -103,14 +103,14 @@ func periodicCleanup[T any](ctx context.Context, c *ExpiringLRUCache[T]) { for { select { case <-ticker.C: - c.cleanUp() + c.cleanUp(ctx) case <-ctx.Done(): return } } } -func (e *ExpiringLRUCache[T]) cleanUp() { +func (e *ExpiringLRUCache[T]) cleanUp(ctx context.Context) { var expiredKeys []string // check for expired items and collect expired keys @@ -126,7 +126,7 @@ func (e *ExpiringLRUCache[T]) cleanUp() { var keysToDelete []string for _, key := range expiredKeys { - newVal, newTTL := e.preExpirationFn(context.Background(), key) + newVal, newTTL := e.preExpirationFn(ctx, key) if newVal != nil { e.Put(key, newVal, newTTL) } else { diff --git a/cache/expirationcache/expiration_cache_test.go b/cache/expirationcache/expiration_cache_test.go index 3e3b8b0c0..95bb46cb8 100644 --- a/cache/expirationcache/expiration_cache_test.go +++ b/cache/expirationcache/expiration_cache_test.go @@ -181,7 +181,7 @@ var _ = Describe("Expiration cache", func() { time.Sleep(2 * time.Millisecond) // trigger cleanUp manually -> onExpiredFn will be executed, because element is expired - cache.cleanUp() + cache.cleanUp(ctx) // wait for expiration val, ttl := cache.Get("key1") From 350b2fe4de8fe3754387d25103502ce3b254134e Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Sun, 21 Apr 2024 15:53:45 +0000 Subject: [PATCH 15/24] redis refactoring 1 --- redis/redis.go | 58 +++++++++++++++------------------------------ redis/redis_test.go | 40 ++++++------------------------- 2 files changed, 26 insertions(+), 72 deletions(-) diff --git a/redis/redis.go b/redis/redis.go index 7c205f19e..df57561b0 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "math" "strings" "time" @@ -31,8 +30,9 @@ const ( // sendBuffer message type bufferMessage struct { + TTL uint32 Key string - Message *dns.Msg + Message []byte } // redis pubsub message @@ -126,10 +126,11 @@ func New(ctx context.Context, cfg *config.Redis) (*Client, error) { return nil, err } -// PublishCache publish cache to redis async -func (c *Client) PublishCache(key string, message *dns.Msg) { - if len(key) > 0 && message != nil { +// PublishCache publish cache entry to redis if key and message are not empty and ttl > 0 +func (c *Client) PublishCache(key string, ttl uint32, message []byte) { + if len(key) > 0 && len(message) > 0 && ttl > 0 { c.sendBuffer <- &bufferMessage{ + TTL: ttl, Key: key, Message: message, } @@ -212,27 +213,20 @@ func (c *Client) startup(ctx context.Context) error { } func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) { - origRes := s.Message - origRes.Compress = true - binRes, pErr := origRes.Pack() - - if pErr == nil { - binMsg, mErr := json.Marshal(redisMessage{ - Key: s.Key, - Type: messageTypeCache, - Message: binRes, - Client: c.id, - }) - - if mErr == nil { - c.client.Publish(ctx, SyncChannelName, binMsg) - } - - c.client.Set(ctx, - prefixKey(s.Key), - binRes, - c.getTTL(origRes)) + psMsg, err := json.Marshal(redisMessage{ + Key: s.Key, + Type: messageTypeCache, + Message: s.Message, + Client: c.id, + }) + if err == nil { + c.client.Publish(ctx, SyncChannelName, psMsg) } + + c.client.Set(ctx, + prefixKey(s.Key), + s.Message, + util.ToTTLDuration(s.TTL)) } func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) { @@ -326,20 +320,6 @@ func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, er return nil, err } -// getTTL of dns message or return defaultCacheTime if 0 -func (c *Client) getTTL(dns *dns.Msg) time.Duration { - ttl := uint32(math.MaxInt32) - for _, a := range dns.Answer { - ttl = min(ttl, a.Header().Ttl) - } - - if ttl == 0 { - return defaultCacheTime - } - - return time.Duration(ttl) * time.Second -} - // prefixKey with CacheStorePrefix func prefixKey(key string) string { return fmt.Sprintf("%s%s", CacheStorePrefix, key) diff --git a/redis/redis_test.go b/redis/redis_test.go index 5c7f530a3..d38dd2db4 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -96,10 +96,12 @@ var _ = Describe("Redis client", func() { By("publish new message with TTL > 0", func() { res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") + Expect(err).Should(Succeed()) + binRes, err := res.Pack() Expect(err).Should(Succeed()) - redisClient.PublishCache("example.com", res) + redisClient.PublishCache("example.com", 123, binRes) }) By("Database has one entry with correct TTL", func() { @@ -111,34 +113,6 @@ var _ = Describe("Redis client", func() { Expect(ttl.Seconds()).Should(BeNumerically("~", 123)) }) }) - - It("One new entry with default TTL should be persisted in the database", func(ctx context.Context) { - redisClient, err = New(ctx, redisConfig) - Expect(err).Should(Succeed()) - - By("Database is empty", func() { - Eventually(func() []string { - return redisServer.DB(redisConfig.Database).Keys() - }).Should(BeEmpty()) - }) - - By("publish new message with TTL = 0", func() { - res, err := util.NewMsgWithAnswer("example.com.", 0, dns.Type(dns.TypeA), "123.124.122.123") - - Expect(err).Should(Succeed()) - - redisClient.PublishCache("example.com", res) - }) - - By("Database has one entry with default TTL", func() { - Eventually(func() bool { - return redisServer.DB(redisConfig.Database).Exists(exampleComKey) - }).Should(BeTrue()) - - ttl := redisServer.DB(redisConfig.Database).TTL(exampleComKey) - Expect(ttl.Seconds()).Should(BeNumerically("~", defaultCacheTime.Seconds())) - }) - }) }) When("Redis client publishes 'enabled' message", func() { It("should propagate the message over redis", func(ctx context.Context) { @@ -312,13 +286,13 @@ var _ = Describe("Redis client", func() { }) By("Put valid data in Redis by publishing the cache entry", func() { - var res *dns.Msg - - res, err = util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") + res, err := util.NewMsgWithAnswer("example.com.", 123, dns.Type(dns.TypeA), "123.124.122.123") + Expect(err).Should(Succeed()) + binRes, err := res.Pack() Expect(err).Should(Succeed()) - redisClient.PublishCache("example.com", res) + redisClient.PublishCache("example.com", 123, binRes) }) By("Database has one entry now", func() { From 423ed031ead8ac34ea57a515ce8f0f0863765328 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Sun, 21 Apr 2024 16:14:50 +0000 Subject: [PATCH 16/24] caching_resolver refactoring --- resolver/caching_resolver.go | 69 ++++++++++++++---------------------- 1 file changed, 26 insertions(+), 43 deletions(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 4e4b88228..e7925a24c 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -121,23 +121,16 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) return nil, 0 } - cacheCopy, ttl := r.createCacheEntry(logger, response.Res) - if cacheCopy == nil || ttl == noCacheTTL { - return nil, 0 - } - - packed, err := cacheCopy.Pack() - if err != nil { - logger.WithError(err).WithError(err).Warn("response packing failed") - + ttl, res := r.createCacheEntry(logger, response.Res) + if ttl == noCacheTTL || len(res) == 0 { return nil, 0 } if r.redisClient != nil { - r.redisClient.PublishCache(cacheKey, cacheCopy) + r.redisClient.PublishCache(cacheKey, ttl, res) } - return &packed, util.ToTTLDuration(ttl) + return &res, util.ToTTLDuration(ttl) } func (r *CachingResolver) redisSubscriber(ctx context.Context) { @@ -153,7 +146,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { dlogger.Debug("received from redis") - r.putInCache(dlogger, rc.Key, rc.Response) + // TODO: Add to cache } case <-ctx.Done(): @@ -203,11 +196,12 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( response, err = r.next.Resolve(ctx, request) if err == nil { - ttl := r.modifyResponseTTL(response.Res) - if ttl > noCacheTTL { - cacheCopy := r.putInCache(logger, cacheKey, response) - if cacheCopy != nil && r.redisClient != nil { - r.redisClient.PublishCache(cacheKey, cacheCopy) + ttl, cacheEntry := r.createCacheEntry(logger, response.Res) + if ttl != noCacheTTL && len(cacheEntry) > 0 { + r.resultCache.Put(cacheKey, &cacheEntry, util.ToTTLDuration(ttl)) + + if r.redisClient != nil { + r.redisClient.PublishCache(cacheKey, ttl, cacheEntry) } } } @@ -250,24 +244,6 @@ func isRequestCacheable(request *model.Request) bool { return true } -func (r *CachingResolver) putInCache(logger *logrus.Entry, cacheKey string, response *model.Response) *dns.Msg { - cacheCopy, ttl := r.createCacheEntry(logger, response.Res) - if cacheCopy == nil || ttl == noCacheTTL { - return nil - } - - packed, err := cacheCopy.Pack() - if err != nil { - logger.WithError(err).Warn("response packing failed") - - return nil - } - - r.resultCache.Put(cacheKey, &packed, util.ToTTLDuration(ttl)) - - return cacheCopy -} - func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 { // if response is empty or negative, return negative cache time from config if len(response.Answer) == 0 || response.Rcode == dns.RcodeNameError { @@ -290,21 +266,28 @@ func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 { return util.GetAnswerMinTTL(response) } -func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg, -) (*dns.Msg, uint32) { - response := input.Copy() - - ttl := r.modifyResponseTTL(response) +func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg) (uint32, []byte) { + ttl := r.modifyResponseTTL(input) if ttl == noCacheTTL { logger.Debug("response is not cacheable") - return nil, 0 + return 0, nil } + internalMsg := input.Copy() + internalMsg.Compress = true + // don't cache any EDNS OPT records - util.RemoveEdns0Record(response) + util.RemoveEdns0Record(internalMsg) + + packed, err := internalMsg.Pack() + if err != nil { + logger.WithError(err).Warn("response packing failed") + + return 0, nil + } - return response, ttl + return ttl, packed } func (r *CachingResolver) publishMetricsIfEnabled(event string, val interface{}) { From 4928399387281369d390969dfcc90caa2fe06de9 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Sun, 21 Apr 2024 16:38:03 +0000 Subject: [PATCH 17/24] redis refactoring 2 --- redis/redis.go | 107 +++++++++++++++-------------------- redis/redis_test.go | 6 +- resolver/caching_resolver.go | 2 +- 3 files changed, 51 insertions(+), 64 deletions(-) diff --git a/redis/redis.go b/redis/redis.go index df57561b0..de642bcbc 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -10,7 +10,6 @@ import ( "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/log" - "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" "github.com/go-redis/redis/v8" "github.com/google/uuid" @@ -28,11 +27,10 @@ const ( messageTypeEnable = 1 ) -// sendBuffer message -type bufferMessage struct { - TTL uint32 - Key string - Message []byte +type CacheEntry struct { + TTL uint32 + Key string + Entry []byte } // redis pubsub message @@ -43,12 +41,6 @@ type redisMessage struct { Client []byte `json:"c"` } -// CacheChannel message -type CacheMessage struct { - Key string - Response *model.Response -} - type EnabledMessage struct { State bool `json:"s"` Duration time.Duration `json:"d,omitempty"` @@ -61,8 +53,8 @@ type Client struct { client *redis.Client l *logrus.Entry id []byte - sendBuffer chan *bufferMessage - CacheChannel chan *CacheMessage + sendBuffer chan *CacheEntry + CacheChannel chan *CacheEntry EnabledChannel chan *EnabledMessage } @@ -111,8 +103,8 @@ func New(ctx context.Context, cfg *config.Redis) (*Client, error) { client: rdb, l: log.PrefixedLog("redis"), id: id, - sendBuffer: make(chan *bufferMessage, chanCap), - CacheChannel: make(chan *CacheMessage, chanCap), + sendBuffer: make(chan *CacheEntry, chanCap), + CacheChannel: make(chan *CacheEntry, chanCap), EnabledChannel: make(chan *EnabledMessage, chanCap), } @@ -129,10 +121,10 @@ func New(ctx context.Context, cfg *config.Redis) (*Client, error) { // PublishCache publish cache entry to redis if key and message are not empty and ttl > 0 func (c *Client) PublishCache(key string, ttl uint32, message []byte) { if len(key) > 0 && len(message) > 0 && ttl > 0 { - c.sendBuffer <- &bufferMessage{ - TTL: ttl, - Key: key, - Message: message, + c.sendBuffer <- &CacheEntry{ + TTL: ttl, + Key: key, + Entry: message, } } } @@ -212,11 +204,11 @@ func (c *Client) startup(ctx context.Context) error { return err } -func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) { +func (c *Client) publishMessageFromBuffer(ctx context.Context, s *CacheEntry) { psMsg, err := json.Marshal(redisMessage{ Key: s.Key, Type: messageTypeCache, - Message: s.Message, + Message: s.Entry, Client: c.id, }) if err == nil { @@ -225,7 +217,7 @@ func (c *Client) publishMessageFromBuffer(ctx context.Context, s *bufferMessage) c.client.Set(ctx, prefixKey(s.Key), - s.Message, + s.Entry, util.ToTTLDuration(s.TTL)) } @@ -242,7 +234,7 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) if !bytes.Equal(rm.Client, c.id) { switch rm.Type { case messageTypeCache: - var cm *CacheMessage + var cm *CacheEntry cm, err := convertMessage(&rm, 0) if err != nil { @@ -269,55 +261,50 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) } // getResponse returns model.Response for a key -func (c *Client) getResponse(ctx context.Context, key string) (*CacheMessage, error) { +func (c *Client) getResponse(ctx context.Context, key string) (*CacheEntry, error) { resp, err := c.client.Get(ctx, key).Result() - if err == nil { - var ttl time.Duration - ttl, err = c.client.TTL(ctx, key).Result() - - if err == nil { - var result *CacheMessage + if err != nil { + return nil, err + } - result, err = convertMessage(&redisMessage{ - Key: cleanKey(key), - Message: []byte(resp), - }, ttl) - if err != nil { - return nil, fmt.Errorf("conversion error: %w", err) - } + ttl, err := c.client.TTL(ctx, key).Result() + if err != nil { + return nil, err + } - return result, nil - } + result := CacheEntry{ + TTL: util.ToTTL(ttl), + Key: cleanKey(key), + Entry: []byte(resp), } - return nil, err + return &result, nil } // convertMessage converts redisMessage to CacheMessage -func convertMessage(message *redisMessage, ttl time.Duration) (*CacheMessage, error) { - msg := dns.Msg{} +func convertMessage(message *redisMessage, ttl time.Duration) (*CacheEntry, error) { + res := CacheEntry{ + TTL: util.ToTTL(ttl), + Key: message.Key, + Entry: message.Message, + } - err := msg.Unpack(message.Message) - if err == nil { - if ttl > 0 { - for _, a := range msg.Answer { - a.Header().Ttl = uint32(ttl.Seconds()) - } - } + // if ttl is set, use it + if res.TTL > 0 { + return &res, nil + } - res := &CacheMessage{ - Key: message.Key, - Response: &model.Response{ - RType: model.ResponseTypeCACHED, - Reason: cacheReason, - Res: &msg, - }, - } + // try to extract ttl from message + var msg *dns.Msg - return res, nil + err := msg.Unpack(message.Message) + if err != nil { + return nil, err } - return nil, err + res.TTL = util.GetAnswerMinTTL(msg) + + return &res, nil } // prefixKey with CacheStorePrefix diff --git a/redis/redis_test.go b/redis/redis_test.go index d38dd2db4..96f06a069 100644 --- a/redis/redis_test.go +++ b/redis/redis_test.go @@ -196,7 +196,7 @@ var _ = Describe("Redis client", func() { rec := redisServer.Publish(SyncChannelName, string(binMsg)) Expect(rec).Should(Equal(1)) - Eventually(func() chan *CacheMessage { + Eventually(func() chan *CacheEntry { return redisClient.CacheChannel }).Should(HaveLen(lenE + 1)) }, SpecTimeout(time.Second*6)) @@ -229,7 +229,7 @@ var _ = Describe("Redis client", func() { return redisClient.EnabledChannel }).Should(HaveLen(lenE)) - Eventually(func() chan *CacheMessage { + Eventually(func() chan *CacheEntry { return redisClient.CacheChannel }).Should(HaveLen(lenC)) }, SpecTimeout(time.Second*6)) @@ -262,7 +262,7 @@ var _ = Describe("Redis client", func() { return redisClient.EnabledChannel }).Should(HaveLen(lenE)) - Eventually(func() chan *CacheMessage { + Eventually(func() chan *CacheEntry { return redisClient.CacheChannel }).Should(HaveLen(lenC)) }, SpecTimeout(time.Second*6)) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index e7925a24c..35f250dec 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -146,7 +146,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { dlogger.Debug("received from redis") - // TODO: Add to cache + r.resultCache.Put(rc.Key, &rc.Entry, util.ToTTLDuration(rc.TTL)) } case <-ctx.Done(): From 17aca2ce01c6dbd6089452cdd78ec6b3846a3594 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Mon, 22 Apr 2024 19:42:43 +0000 Subject: [PATCH 18/24] fast return on zero or below --- util/dns.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/util/dns.go b/util/dns.go index 25b7b4cba..5d2797ebe 100644 --- a/util/dns.go +++ b/util/dns.go @@ -22,6 +22,11 @@ type TTLInput interface { // // If the input is greater than the maximum value of uint32, the TTL is set to math.MaxUint32. func ToTTL[T TTLInput](input T) uint32 { + // fast return if the input is zero or below + if input <= 0 { + return 0 + } + // fast return if the input is already of type uint32 if ui32Type, ok := any(input).(uint32); ok { return ui32Type From 05446fa9eddbc597df4f9c845416b57233f1aab0 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Mon, 22 Apr 2024 19:43:01 +0000 Subject: [PATCH 19/24] redis refactoring --- redis/redis.go | 40 ++++++++++++++++------------------------ 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/redis/redis.go b/redis/redis.go index de642bcbc..53af049d6 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -204,6 +204,7 @@ func (c *Client) startup(ctx context.Context) error { return err } +// publishMessageFromBuffer publishes a message from the buffer to the redis channel and stores it in the cache func (c *Client) publishMessageFromBuffer(ctx context.Context, s *CacheEntry) { psMsg, err := json.Marshal(redisMessage{ Key: s.Key, @@ -234,16 +235,14 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) if !bytes.Equal(rm.Client, c.id) { switch rm.Type { case messageTypeCache: - var cm *CacheEntry - - cm, err := convertMessage(&rm, 0) + cm, err := convertMessage(0, rm.Key, rm.Message) if err != nil { - c.l.Error("Processing CacheMessage error: ", err) + c.l.Error(err) return } - util.CtxSend(ctx, c.CacheChannel, cm) + util.CtxSend(ctx, c.CacheChannel, &cm) case messageTypeEnable: var msg EnabledMessage @@ -272,39 +271,32 @@ func (c *Client) getResponse(ctx context.Context, key string) (*CacheEntry, erro return nil, err } - result := CacheEntry{ - TTL: util.ToTTL(ttl), - Key: cleanKey(key), - Entry: []byte(resp), + result, err := convertMessage(ttl, cleanKey(key), []byte(resp)) + if err != nil { + return nil, err } return &result, nil } // convertMessage converts redisMessage to CacheMessage -func convertMessage(message *redisMessage, ttl time.Duration) (*CacheEntry, error) { +func convertMessage[T util.TTLInput](ttl T, key string, message []byte) (CacheEntry, error) { res := CacheEntry{ TTL: util.ToTTL(ttl), - Key: message.Key, - Entry: message.Message, - } - - // if ttl is set, use it - if res.TTL > 0 { - return &res, nil + Key: key, + Entry: message, } - // try to extract ttl from message var msg *dns.Msg - - err := msg.Unpack(message.Message) - if err != nil { - return nil, err + if err := msg.Unpack(message); err != nil { + return res, fmt.Errorf("invalid message for key %s", key) } - res.TTL = util.GetAnswerMinTTL(msg) + if ttl == 0 { + res.TTL = util.GetAnswerMinTTL(msg) + } - return &res, nil + return res, nil } // prefixKey with CacheStorePrefix From c7986899b8c56022f7899bd3d54a9822d964824e Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Mon, 22 Apr 2024 19:48:43 +0000 Subject: [PATCH 20/24] fix caching resolver test --- resolver/caching_resolver_test.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/resolver/caching_resolver_test.go b/resolver/caching_resolver_test.go index 5081ab543..20caa196c 100644 --- a/resolver/caching_resolver_test.go +++ b/resolver/caching_resolver_test.go @@ -782,13 +782,13 @@ var _ = Describe("CachingResolver", func() { request := newRequest("example2.com.", A) domain := util.ExtractDomain(request.Req.Question[0]) cacheKey := util.GenerateCacheKey(A, domain) - redisMockMsg := &redis.CacheMessage{ - Key: cacheKey, - Response: &Response{ - RType: ResponseTypeCACHED, - Reason: "MOCK_REDIS", - Res: mockAnswer, - }, + binMsg, err := mockAnswer.Pack() + Expect(err).Should(Succeed()) + + redisMockMsg := &redis.CacheEntry{ + TTL: 123, + Key: cacheKey, + Entry: binMsg, } redisClient.CacheChannel <- redisMockMsg From f8093f30d8861ee574caab6501ea242aa214ecbb Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Mon, 22 Apr 2024 19:54:29 +0000 Subject: [PATCH 21/24] nil check --- redis/redis.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/redis/redis.go b/redis/redis.go index 53af049d6..594fc1f01 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -287,9 +287,14 @@ func convertMessage[T util.TTLInput](ttl T, key string, message []byte) (CacheEn Entry: message, } + packErr := fmt.Errorf("invalid message for key %s", key) + if len(message) == 0 { + return res, packErr + } + var msg *dns.Msg if err := msg.Unpack(message); err != nil { - return res, fmt.Errorf("invalid message for key %s", key) + return res, packErr } if ttl == 0 { From 61c83af493513011274a7f5bb8df4d7514afc2f1 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Mon, 22 Apr 2024 19:57:57 +0000 Subject: [PATCH 22/24] new pointers --- redis/redis.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/redis/redis.go b/redis/redis.go index 594fc1f01..935cd677f 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -244,15 +244,15 @@ func (c *Client) processReceivedMessage(ctx context.Context, msg *redis.Message) util.CtxSend(ctx, c.CacheChannel, &cm) case messageTypeEnable: - var msg EnabledMessage + msg := new(EnabledMessage) - if err := json.Unmarshal(rm.Message, &msg); err != nil { + if err := json.Unmarshal(rm.Message, msg); err != nil { c.l.Error("Processing EnabledMessage error: ", err) return } - util.CtxSend(ctx, c.EnabledChannel, &msg) + util.CtxSend(ctx, c.EnabledChannel, msg) default: c.l.Warn("Unknown message type: ", rm.Type) } @@ -292,7 +292,7 @@ func convertMessage[T util.TTLInput](ttl T, key string, message []byte) (CacheEn return res, packErr } - var msg *dns.Msg + msg := new(dns.Msg) if err := msg.Unpack(message); err != nil { return res, packErr } From 73cbd9c15e77f1966c0004e6a0f4866cf1acb8fb Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Mon, 22 Apr 2024 20:06:00 +0000 Subject: [PATCH 23/24] removed const --- resolver/caching_resolver.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/resolver/caching_resolver.go b/resolver/caching_resolver.go index 35f250dec..cf30cabf7 100644 --- a/resolver/caching_resolver.go +++ b/resolver/caching_resolver.go @@ -17,8 +17,6 @@ import ( const ( defaultCachingCleanUpInterval = 5 * time.Second - // noCacheTTL indicates that a response should not be cached - noCacheTTL = uint32(0) ) // CachingResolver caches answers from dns queries with their TTL time, @@ -122,7 +120,7 @@ func (r *CachingResolver) reloadCacheEntry(ctx context.Context, cacheKey string) } ttl, res := r.createCacheEntry(logger, response.Res) - if ttl == noCacheTTL || len(res) == 0 { + if ttl == 0 || len(res) == 0 { return nil, 0 } @@ -142,9 +140,7 @@ func (r *CachingResolver) redisSubscriber(ctx context.Context) { if rc != nil { _, domain := util.ExtractCacheKey(rc.Key) - dlogger := logger.WithField("domain", util.Obfuscate(domain)) - - dlogger.Debug("received from redis") + logger.WithField("domain", util.Obfuscate(domain)).Debug("received from redis") r.resultCache.Put(rc.Key, &rc.Entry, util.ToTTLDuration(rc.TTL)) } @@ -197,7 +193,7 @@ func (r *CachingResolver) Resolve(ctx context.Context, request *model.Request) ( response, err = r.next.Resolve(ctx, request) if err == nil { ttl, cacheEntry := r.createCacheEntry(logger, response.Res) - if ttl != noCacheTTL && len(cacheEntry) > 0 { + if ttl != 0 && len(cacheEntry) > 0 { r.resultCache.Put(cacheKey, &cacheEntry, util.ToTTLDuration(ttl)) if r.redisClient != nil { @@ -268,7 +264,7 @@ func (r *CachingResolver) modifyResponseTTL(response *dns.Msg) uint32 { func (r *CachingResolver) createCacheEntry(logger *logrus.Entry, input *dns.Msg) (uint32, []byte) { ttl := r.modifyResponseTTL(input) - if ttl == noCacheTTL { + if ttl == 0 { logger.Debug("response is not cacheable") return 0, nil From 7b2fcc2953a58d3e294a6105d96cd69514960368 Mon Sep 17 00:00:00 2001 From: Kwitsch Date: Tue, 23 Apr 2024 16:55:48 +0000 Subject: [PATCH 24/24] redis refactoring 2 --- redis/redis.go | 116 +++++++++++++++++++++++++++---------------------- util/dns.go | 2 +- 2 files changed, 64 insertions(+), 54 deletions(-) diff --git a/redis/redis.go b/redis/redis.go index 935cd677f..82dab7659 100644 --- a/redis/redis.go +++ b/redis/redis.go @@ -92,30 +92,34 @@ func New(ctx context.Context, cfg *config.Redis) (*Client, error) { rdb := baseClient.WithContext(ctx) _, err := rdb.Ping(ctx).Result() - if err == nil { - var id []byte - - id, err = uuid.New().MarshalBinary() - if err == nil { - // construct client - res := &Client{ - config: cfg, - client: rdb, - l: log.PrefixedLog("redis"), - id: id, - sendBuffer: make(chan *CacheEntry, chanCap), - CacheChannel: make(chan *CacheEntry, chanCap), - EnabledChannel: make(chan *EnabledMessage, chanCap), - } + if err != nil { + return nil, err + } - // start channel handling go routine - err = res.startup(ctx) + var id []byte - return res, err - } + id, err = uuid.New().MarshalBinary() + if err != nil { + return nil, err + } + // construct client + res := &Client{ + config: cfg, + client: rdb, + l: log.PrefixedLog("redis"), + id: id, + sendBuffer: make(chan *CacheEntry, chanCap), + CacheChannel: make(chan *CacheEntry, chanCap), + EnabledChannel: make(chan *EnabledMessage, chanCap), } - return nil, err + // start channel handling go routine + err = res.startup(ctx) + if err != nil { + return nil, err + } + + return res, nil } // PublishCache publish cache entry to redis if key and message are not empty and ttl > 0 @@ -130,18 +134,21 @@ func (c *Client) PublishCache(key string, ttl uint32, message []byte) { } func (c *Client) PublishEnabled(ctx context.Context, state *EnabledMessage) { - binState, sErr := json.Marshal(state) - if sErr == nil { - binMsg, mErr := json.Marshal(redisMessage{ - Type: messageTypeEnable, - Message: binState, - Client: c.id, - }) + binState, err := json.Marshal(state) + if err != nil { + return + } - if mErr == nil { - c.client.Publish(ctx, SyncChannelName, binMsg) - } + binMsg, err := json.Marshal(redisMessage{ + Type: messageTypeEnable, + Message: binState, + Client: c.id, + }) + if err != nil { + return } + + c.client.Publish(ctx, SyncChannelName, binMsg) } // GetRedisCache reads the redis cache and publish it to the channel @@ -176,32 +183,35 @@ func (c *Client) startup(ctx context.Context) error { ps := c.client.Subscribe(ctx, SyncChannelName) _, err := ps.Receive(ctx) - if err == nil { - go func() { - for { - select { - // received message from subscription - case msg := <-ps.Channel(): - c.l.Debug("Received message: ", msg) - - if msg != nil && len(msg.Payload) > 0 { - // message is not empty - c.processReceivedMessage(ctx, msg) - } - // publish message from buffer - case s := <-c.sendBuffer: - c.publishMessageFromBuffer(ctx, s) - // context is done - case <-ctx.Done(): - c.client.Close() - - return + if err != nil { + return err + } + + go func() { + defer ps.Close() + defer c.client.Close() + + for { + select { + // received message from subscription + case msg := <-ps.Channel(): + c.l.Debug("Received message: ", msg) + + if msg != nil && len(msg.Payload) > 0 { + // message is not empty + c.processReceivedMessage(ctx, msg) } + // publish message from buffer + case s := <-c.sendBuffer: + c.publishMessageFromBuffer(ctx, s) + // context is done + case <-ctx.Done(): + return } - }() - } + } + }() - return err + return nil } // publishMessageFromBuffer publishes a message from the buffer to the redis channel and stores it in the cache diff --git a/util/dns.go b/util/dns.go index 5d2797ebe..fcf635ec3 100644 --- a/util/dns.go +++ b/util/dns.go @@ -86,7 +86,7 @@ func SetAnswerMaxTTL[T TTLInput](msg *dns.Msg, max T) { // SetAnswerMinMaxTTL sets the TTL of all answers in the message that are less than the specified minimum TTL // to the minimum TTL and the TTL of all answers that are greater than the specified maximum TTL to the maximum TTL. -func SetAnswerMinMaxTTL[T TTLInput, TT TTLInput](msg *dns.Msg, min T, max TT) { +func SetAnswerMinMaxTTL[T, TT TTLInput](msg *dns.Msg, min T, max TT) { minTTL := ToTTL(min) maxTTL := ToTTL(max)