diff --git a/internal/webconnectivityalgo/dnswhoami.go b/internal/webconnectivityalgo/dnswhoami.go index 084eadc5c..529f94fd7 100644 --- a/internal/webconnectivityalgo/dnswhoami.go +++ b/internal/webconnectivityalgo/dnswhoami.go @@ -14,14 +14,21 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" + "github.com/ooni/probe-cli/v3/internal/optional" ) // DNSWhoamiInfoEntry contains an entry for DNSWhoamiInfo. type DNSWhoamiInfoEntry struct { - // Address is the IP address + // Address is the IP address used by the resolver. Address string `json:"address"` } +// dnsWhoamiInfoTimedEntry keeps an address and the time we created the entry together. +type dnsWhoamiInfoTimedEntry struct { + Addr string + T time.Time +} + // TODO(bassosimone): this code needs refining before we can merge it inside // master. For one, we already have systemv4 info. Additionally, it would // be neat to avoid additional AAAA queries. Furthermore, we should also see @@ -30,27 +37,25 @@ type DNSWhoamiInfoEntry struct { // TODO(bassosimone): consider factoring this code and keeping state // on disk rather than on memory. -// TODO(bassosimone): we should periodically invalidate the whoami lookup results. - // DNSWhoamiService is a service that performs DNS whoami lookups. // // The zero value of this struct is invalid. Please, construct using // the [NewDNSWhoamiService] factory function. type DNSWhoamiService struct { - // logger is the logger + // entries contains the entries. + entries map[string]*dnsWhoamiInfoTimedEntry + + // logger is the logger. logger model.Logger - // mu provides mutual exclusion + // mu provides mutual exclusion. mu *sync.Mutex - // netx is the underlying network we're using + // netx is the underlying network we're using. netx *netxlite.Netx - // systemv4 contains systemv4 results - systemv4 []DNSWhoamiInfoEntry - - // udpv4 contains udpv4 results - udpv4 map[string][]DNSWhoamiInfoEntry + // timeNow allows to get the current time. + timeNow func() time.Time // whoamiDomain is the whoamiDomain to query for. whoamiDomain string @@ -59,53 +64,115 @@ type DNSWhoamiService struct { // NewDNSWhoamiService constructs a new [*DNSWhoamiService]. func NewDNSWhoamiService(logger model.Logger) *DNSWhoamiService { return &DNSWhoamiService{ + entries: map[string]*dnsWhoamiInfoTimedEntry{}, logger: logger, mu: &sync.Mutex{}, netx: &netxlite.Netx{Underlying: nil}, - systemv4: []DNSWhoamiInfoEntry{}, - udpv4: map[string][]DNSWhoamiInfoEntry{}, + timeNow: time.Now, whoamiDomain: "whoami.v4.powerdns.org", } } // SystemV4 returns the results of querying using the system resolver and IPv4. func (svc *DNSWhoamiService) SystemV4(ctx context.Context) ([]DNSWhoamiInfoEntry, bool) { - svc.mu.Lock() - defer svc.mu.Unlock() - if len(svc.systemv4) <= 0 { - ctx, cancel := context.WithTimeout(ctx, 4*time.Second) - defer cancel() - reso := svc.netx.NewStdlibResolver(svc.logger) - addrs, err := reso.LookupHost(ctx, svc.whoamiDomain) - if err != nil || len(addrs) < 1 { - return nil, false - } - svc.systemv4 = []DNSWhoamiInfoEntry{{ - Address: addrs[0], - }} + spec := &dnsWhoamiResolverSpec{ + name: "system:///", + factory: func(logger model.Logger, netx *netxlite.Netx) model.Resolver { + return svc.netx.NewStdlibResolver(svc.logger) + }, } - return svc.systemv4, len(svc.systemv4) > 0 + v := svc.lookup(ctx, spec) + return v, len(v) > 0 } // UDPv4 returns the results of querying a given UDP resolver and IPv4. func (svc *DNSWhoamiService) UDPv4(ctx context.Context, address string) ([]DNSWhoamiInfoEntry, bool) { + spec := &dnsWhoamiResolverSpec{ + name: address, + factory: func(logger model.Logger, netx *netxlite.Netx) model.Resolver { + dialer := svc.netx.NewDialerWithResolver(svc.logger, svc.netx.NewStdlibResolver(svc.logger)) + return svc.netx.NewParallelUDPResolver(svc.logger, dialer, address) + }, + } + v := svc.lookup(ctx, spec) + return v, len(v) > 0 +} + +type dnsWhoamiResolverSpec struct { + name string + factory func(logger model.Logger, netx *netxlite.Netx) model.Resolver +} + +func (svc *DNSWhoamiService) lookup(ctx context.Context, spec *dnsWhoamiResolverSpec) []DNSWhoamiInfoEntry { + // get the current time + now := svc.timeNow() + + // possibly use cache + mentry := svc.lockAndGet(now, spec.name) + if !mentry.IsNone() { + return []DNSWhoamiInfoEntry{mentry.Unwrap()} + } + + // perform lookup + ctx, cancel := context.WithTimeout(ctx, 4*time.Second) + defer cancel() + reso := spec.factory(svc.logger, svc.netx) + addrs, err := reso.LookupHost(ctx, svc.whoamiDomain) + if err != nil || len(addrs) < 1 { + return nil + } + + // update cache + svc.lockAndUpdate(now, spec.name, addrs[0]) + + // return to the caller + return []DNSWhoamiInfoEntry{{Address: addrs[0]}} +} + +func (svc *DNSWhoamiService) lockAndGet(now time.Time, serverAddr string) optional.Value[DNSWhoamiInfoEntry] { + // ensure there's mutual exclusion + defer svc.mu.Unlock() + svc.mu.Lock() + + // see if there's an entry + entry, found := svc.entries[serverAddr] + if !found { + return optional.None[DNSWhoamiInfoEntry]() + } + + // make sure the entry has not expired + const validity = 45 * time.Second + if now.Sub(entry.T) > validity { + return optional.None[DNSWhoamiInfoEntry]() + } + + // return a copy of the value + return optional.Some(DNSWhoamiInfoEntry{ + Address: entry.Addr, + }) +} + +func (svc *DNSWhoamiService) lockAndUpdate(now time.Time, serverAddr, whoamiAddr string) { + // ensure there's mutual exclusion + defer svc.mu.Unlock() svc.mu.Lock() + + // insert into the table + svc.entries[serverAddr] = &dnsWhoamiInfoTimedEntry{ + Addr: whoamiAddr, + T: now, + } +} + +func (svc *DNSWhoamiService) cloneEntries() map[string]*dnsWhoamiInfoTimedEntry { defer svc.mu.Unlock() - if len(svc.udpv4[address]) <= 0 { - ctx, cancel := context.WithTimeout(ctx, 4*time.Second) - defer cancel() - dialer := svc.netx.NewDialerWithResolver(svc.logger, svc.netx.NewStdlibResolver(svc.logger)) - reso := svc.netx.NewParallelUDPResolver(svc.logger, dialer, address) - // TODO(bassosimone): this should actually only send an A query. Sending an AAAA - // query is _way_ unnecessary since we know that only A is going to work. - addrs, err := reso.LookupHost(ctx, svc.whoamiDomain) - if err != nil || len(addrs) < 1 { - return nil, false + svc.mu.Lock() + output := make(map[string]*dnsWhoamiInfoTimedEntry) + for key, value := range svc.entries { + output[key] = &dnsWhoamiInfoTimedEntry{ + Addr: value.Addr, + T: value.T, } - svc.udpv4[address] = []DNSWhoamiInfoEntry{{ - Address: addrs[0], - }} } - value := svc.udpv4[address] - return value, len(value) > 0 + return output } diff --git a/internal/webconnectivityalgo/dnswhoami_test.go b/internal/webconnectivityalgo/dnswhoami_test.go index 0de0c8b01..427541b09 100644 --- a/internal/webconnectivityalgo/dnswhoami_test.go +++ b/internal/webconnectivityalgo/dnswhoami_test.go @@ -3,6 +3,7 @@ package webconnectivityalgo import ( "context" "testing" + "time" "github.com/apex/log" "github.com/google/go-cmp/cmp" @@ -11,8 +12,8 @@ import ( ) func TestDNSWhoamiService(t *testing.T) { - // expectation describes expectations - type expectation struct { + // callResults contains the results of calling System or UDPv4 + type callResults struct { Entries []DNSWhoamiInfoEntry Good bool } @@ -25,14 +26,17 @@ func TestDNSWhoamiService(t *testing.T) { // domain is the domain to query for domain string - // expectations contains the expecations - expectations []expectation + // internals contains the expected internals cache + internals map[string]*dnsWhoamiInfoTimedEntry + + // callResults contains the expectations + callResults []callResults } cases := []testcase{{ name: "common case using the default domain", domain: "", // forces using default - expectations: []expectation{{ + callResults: []callResults{{ Entries: []DNSWhoamiInfoEntry{{ Address: netemx.DefaultClientAddress, }}, @@ -43,16 +47,27 @@ func TestDNSWhoamiService(t *testing.T) { }}, Good: true, }}, + internals: map[string]*dnsWhoamiInfoTimedEntry{ + "system:///": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(time.Second), + }, + "8.8.8.8:53": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(2 * time.Second), + }, + }, }, { name: "error case using another domain", domain: "example.xyz", - expectations: []expectation{{ + callResults: []callResults{{ Entries: nil, Good: false, }, { Entries: nil, Good: false, }}, + internals: map[string]*dnsWhoamiInfoTimedEntry{}, }} for _, tc := range cases { @@ -69,29 +84,142 @@ func TestDNSWhoamiService(t *testing.T) { if tc.domain != "" { svc.whoamiDomain = tc.domain } + svc.timeNow = (&testTimeProvider{ + t0: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC), + times: []time.Duration{ + time.Second, + 2 * time.Second, + }, + idx: 0, + }).timeNow // prepare collecting results - var results []expectation + var results []callResults // run with the system resolver sysEntries, sysGood := svc.SystemV4(context.Background()) - results = append(results, expectation{ + results = append(results, callResults{ Entries: sysEntries, Good: sysGood, }) // run with an UDP resolver udpEntries, udpGood := svc.UDPv4(context.Background(), "8.8.8.8:53") - results = append(results, expectation{ + results = append(results, callResults{ Entries: udpEntries, Good: udpGood, }) // check whether we've got what we expected - if diff := cmp.Diff(tc.expectations, results); diff != "" { + if diff := cmp.Diff(tc.callResults, results); diff != "" { + t.Fatal(diff) + } + + // check the internals + if diff := cmp.Diff(tc.internals, svc.cloneEntries()); diff != "" { t.Fatal(diff) } }) } + t.Run("we correctly handle cache expiration", func(t *testing.T) { + // create testing scenario + env := netemx.MustNewScenario(netemx.InternetScenario) + defer env.Close() + + // create the service + svc := NewDNSWhoamiService(log.Log) + + // create the timeTestProvider + ttp := &testTimeProvider{ + t0: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC), + times: []time.Duration{ + // first run + time.Second, + 2 * time.Second, + // second run + 15 * time.Second, + 17 * time.Second, + // third run + 60 * time.Second, + 62 * time.Second, + }, + idx: 0, + } + + // override fields + svc.netx = &netxlite.Netx{Underlying: &netxlite.NetemUnderlyingNetworkAdapter{UNet: env.ClientStack}} + svc.timeNow = ttp.timeNow + + // run for the first time + _, _ = svc.SystemV4(context.Background()) + _, _ = svc.UDPv4(context.Background(), "8.8.8.8:53") + + // establish expectations for first run + // + // we expect the internals to be related to the first run + expectFirstInternals := map[string]*dnsWhoamiInfoTimedEntry{ + "system:///": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(time.Second), + }, + "8.8.8.8:53": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(2 * time.Second), + }, + } + + // check the internals for the first run + if diff := cmp.Diff(expectFirstInternals, svc.cloneEntries()); diff != "" { + t.Fatal(diff) + } + + // run for the second time + _, _ = svc.SystemV4(context.Background()) + _, _ = svc.UDPv4(context.Background(), "8.8.8.8:53") + + // establish expectations for second run + // + // we expect the internals to be related to the first run because not + // enough time has elapsed since we create the cache entries + expectSecondInternals := map[string]*dnsWhoamiInfoTimedEntry{ + "system:///": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(time.Second), + }, + "8.8.8.8:53": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(2 * time.Second), + }, + } + + // check the internals for the second run + if diff := cmp.Diff(expectSecondInternals, svc.cloneEntries()); diff != "" { + t.Fatal(diff) + } + + // run for the third time + _, _ = svc.SystemV4(context.Background()) + _, _ = svc.UDPv4(context.Background(), "8.8.8.8:53") + + // establish expectations for third run + // + // we expect the cache to be related to the third run because now the + // entries are stale and so we perform another lookup + expectThirdInternals := map[string]*dnsWhoamiInfoTimedEntry{ + "system:///": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(60 * time.Second), + }, + "8.8.8.8:53": { + Addr: netemx.DefaultClientAddress, + T: time.Date(2024, 2, 8, 9, 8, 7, 6, time.UTC).Add(62 * time.Second), + }, + } + + // check the internals for the second run + if diff := cmp.Diff(expectThirdInternals, svc.cloneEntries()); diff != "" { + t.Fatal(diff) + } + }) }