diff --git a/internal/bytecounter/context.go b/internal/bytecounter/context.go index 3817250be5..0c593633dc 100644 --- a/internal/bytecounter/context.go +++ b/internal/bytecounter/context.go @@ -11,7 +11,7 @@ import ( type byteCounterSessionKey struct{} -// ContextSessionByteCounter retrieves the session byte counter from the context +// ContextSessionByteCounter retrieves the possibly-nil session byte counter from the context. func ContextSessionByteCounter(ctx context.Context) *Counter { counter, _ := ctx.Value(byteCounterSessionKey{}).(*Counter) return counter @@ -24,7 +24,7 @@ func WithSessionByteCounter(ctx context.Context, counter *Counter) context.Conte type byteCounterExperimentKey struct{} -// ContextExperimentByteCounter retrieves the experiment byte counter from the context +// ContextExperimentByteCounter retrieves the possibly-nil experiment byte counter from the context. func ContextExperimentByteCounter(ctx context.Context) *Counter { counter, _ := ctx.Value(byteCounterExperimentKey{}).(*Counter) return counter diff --git a/internal/bytecounter/dialer.go b/internal/bytecounter/dialer.go index b016bf491f..fb6cafbff8 100644 --- a/internal/bytecounter/dialer.go +++ b/internal/bytecounter/dialer.go @@ -14,7 +14,7 @@ import ( // MaybeWrapWithContextAwareDialer wraps the given dialer with a ContextAwareDialer // if the enabled argument is true and otherwise just returns the given dialer. // -// # Bug +// # Caveat // // This implementation cannot properly account for the bytes that are sent by // persistent connections, because they stick to the counters set when the @@ -22,8 +22,7 @@ import ( // received when submitting a measurement. Such bytes are specifically not // seen by the experiment specific byte counter. // -// For this reason, this implementation may be heavily changed/removed -// in the future (<- this message is now ~two years old, though). +// As such, this implementation should only be used when measuring. func MaybeWrapWithContextAwareDialer(enabled bool, dialer model.Dialer) model.Dialer { if !enabled { return dialer diff --git a/internal/bytecounter/resolver.go b/internal/bytecounter/resolver.go index 1a86153b73..453ea931f4 100644 --- a/internal/bytecounter/resolver.go +++ b/internal/bytecounter/resolver.go @@ -8,11 +8,59 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) +// WrapWithContextAwareSystemResolver wraps the given resolver with a resolver that +// is aware of context-byte counting. See MaybeWrapSystemResolver for a list of caveats. +func WrapWithContextAwareSystemResolver(reso model.Resolver) model.Resolver { + return &ContextAwareSystemResolver{reso} +} + +// ContextAwareSystemResolver is a [model.Resolver] that knows how to count bytes +// sent and received. We typically use this for the system resolver only because for +// other resolvers we are better off just wrapping their connections. +type ContextAwareSystemResolver struct { + R model.Resolver +} + +// Address implements model.Resolver. +func (r *ContextAwareSystemResolver) Address() string { + return r.R.Address() +} + +// CloseIdleConnections implements model.Resolver. +func (r *ContextAwareSystemResolver) CloseIdleConnections() { + r.R.CloseIdleConnections() +} + +func (r *ContextAwareSystemResolver) wrap(ctx context.Context) model.Resolver { + return MaybeWrapSystemResolver(MaybeWrapSystemResolver( + r.R, ContextSessionByteCounter(ctx)), ContextExperimentByteCounter(ctx)) +} + +// LookupHTTPS implements model.Resolver. +func (r *ContextAwareSystemResolver) LookupHTTPS(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + return r.wrap(ctx).LookupHTTPS(ctx, domain) +} + +// LookupHost implements model.Resolver. +func (r *ContextAwareSystemResolver) LookupHost(ctx context.Context, hostname string) (addrs []string, err error) { + return r.wrap(ctx).LookupHost(ctx, hostname) +} + +// LookupNS implements model.Resolver. +func (r *ContextAwareSystemResolver) LookupNS(ctx context.Context, domain string) ([]*net.NS, error) { + return r.wrap(ctx).LookupNS(ctx, domain) +} + +// Network implements model.Resolver. +func (r *ContextAwareSystemResolver) Network() string { + return r.R.Network() +} + // MaybeWrapSystemResolver takes in input a Resolver and either wraps it // to perform byte counting, if this counter is not nil, or just returns to the // caller the original resolver, when the counter is nil. // -// # Bug +// # Caveat // // The returned resolver will only approximately estimate the bytes // sent and received by this resolver if this resolver is the system diff --git a/internal/bytecounter/resolver_test.go b/internal/bytecounter/resolver_test.go index f937debc80..37ab4f0498 100644 --- a/internal/bytecounter/resolver_test.go +++ b/internal/bytecounter/resolver_test.go @@ -283,3 +283,274 @@ func TestMaybeWrapSystemResolver(t *testing.T) { } }) } + +func TestWrapWithContextAwareSystemResolver(t *testing.T) { + t.Run("Address works as intended", func(t *testing.T) { + underlying := &mocks.Resolver{ + MockAddress: func() string { + return "8.8.8.8:53" + }, + } + reso := WrapWithContextAwareSystemResolver(underlying) + if reso.Address() != "8.8.8.8:53" { + t.Fatal("unexpected result") + } + }) + + t.Run("CloseIdleConnections works as intended", func(t *testing.T) { + var called bool + underlying := &mocks.Resolver{ + MockCloseIdleConnections: func() { + called = true + }, + } + reso := WrapWithContextAwareSystemResolver(underlying) + reso.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) + + t.Run("LookupHTTPS works as intended", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + expected := &model.HTTPSSvc{} + underlying := &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + return expected, nil + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupHTTPS(ctx, "dns.google") + if err != nil { + t.Fatal("unexpected error", err) + } + if got != expected { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 10 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 256 { + t.Fatal("unexpected nrecv") + } + }) + + t.Run("on non-DNS failure", func(t *testing.T) { + expected := errors.New("mocked error") + underlying := &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + return nil, expected + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupHTTPS(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if got != nil { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 10 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 0 { + t.Fatal("unexpected nrecv") + } + }) + + t.Run("on DNS failure", func(t *testing.T) { + expected := errors.New(netxlite.FailureDNSNXDOMAINError) + underlying := &mocks.Resolver{ + MockLookupHTTPS: func(ctx context.Context, domain string) (*model.HTTPSSvc, error) { + return nil, expected + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupHTTPS(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if got != nil { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 10 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 128 { + t.Fatal("unexpected nrecv") + } + }) + }) + + t.Run("LookupNS works as intended", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + underlying := &mocks.Resolver{ + MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) { + out := make([]*net.NS, 3) + return out, nil + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupNS(ctx, "dns.google") + if err != nil { + t.Fatal("unexpected error", err) + } + if len(got) != 3 { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 10 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 256 { + t.Fatal("unexpected nrecv") + } + }) + + t.Run("on non-DNS failure", func(t *testing.T) { + expected := errors.New("mocked error") + underlying := &mocks.Resolver{ + MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) { + return nil, expected + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupNS(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if len(got) != 0 { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 10 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 0 { + t.Fatal("unexpected nrecv") + } + }) + + t.Run("on DNS failure", func(t *testing.T) { + expected := errors.New(netxlite.FailureDNSNXDOMAINError) + underlying := &mocks.Resolver{ + MockLookupNS: func(ctx context.Context, domain string) ([]*net.NS, error) { + return nil, expected + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupNS(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if len(got) != 0 { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 10 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 128 { + t.Fatal("unexpected nrecv") + } + }) + }) + + t.Run("LookupHost works as intended", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + underlying := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + out := make([]string, 3) + return out, nil + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupHost(ctx, "dns.google") + if err != nil { + t.Fatal("unexpected error", err) + } + if len(got) != 3 { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 20 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 256 { + t.Fatal("unexpected nrecv") + } + }) + + t.Run("on non-DNS failure", func(t *testing.T) { + expected := errors.New("mocked error") + underlying := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupHost(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if len(got) != 0 { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 20 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 0 { + t.Fatal("unexpected nrecv") + } + }) + + t.Run("on DNS failure", func(t *testing.T) { + expected := errors.New(netxlite.FailureDNSNXDOMAINError) + underlying := &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, expected + }, + } + counter := New() + reso := WrapWithContextAwareSystemResolver(underlying) + ctx := WithSessionByteCounter(context.Background(), counter) + got, err := reso.LookupHost(ctx, "dns.google") + if !errors.Is(err, expected) { + t.Fatal("unexpected error", err) + } + if len(got) != 0 { + t.Fatal("invalid result") + } + if nsent := counter.BytesSent(); nsent != 20 { + t.Fatal("unexpected nsent", nsent) + } + if nrecv := counter.BytesReceived(); nrecv != 128 { + t.Fatal("unexpected nrecv") + } + }) + }) + + t.Run("Network works as intended", func(t *testing.T) { + underlying := &mocks.Resolver{ + MockNetwork: func() string { + return "udp" + }, + } + reso := WrapWithContextAwareSystemResolver(underlying) + if reso.Network() != "udp" { + t.Fatal("unexpected result") + } + }) +} diff --git a/internal/measurexlite/bytecounting_test.go b/internal/measurexlite/bytecounting_test.go new file mode 100644 index 0000000000..e2e779ccfa --- /dev/null +++ b/internal/measurexlite/bytecounting_test.go @@ -0,0 +1,125 @@ +package measurexlite_test + +import ( + "context" + "crypto/tls" + "testing" + "time" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/bytecounter" + "github.com/ooni/probe-cli/v3/internal/measurexlite" +) + +func TestCountSystemResolverBytes(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // create the trace + tx := measurexlite.NewTrace(0, time.Now()) + + // create the context + ctx := context.Background() + + // add and register session byte counter + sbc := bytecounter.New() + ctx = bytecounter.WithSessionByteCounter(ctx, sbc) + + // add and register experiment byte counter + ebc := bytecounter.New() + ctx = bytecounter.WithExperimentByteCounter(ctx, ebc) + + // create system resolver + reso := tx.NewStdlibResolver(log.Log) + defer reso.CloseIdleConnections() + + // run a lookup + addrs, err := reso.LookupHost(ctx, "www.example.com") + + // make sure we didn't fail + if err != nil { + t.Fatal(err) + } + + // make sure we resolved addresses + if len(addrs) <= 0 { + t.Fatal("expected at least one address") + } + + // make sure we received something + if sbc.Received.Load() <= 0 { + t.Fatal("sbs.Received.Load() returned zero or less") + } + if ebc.Received.Load() <= 0 { + t.Fatal("ebc.Received.Load() returned zero or less") + } + + // make sure we send something + if sbc.Sent.Load() <= 0 { + t.Fatal("sbs.Sent.Load() returned zero or less") + } + if ebc.Sent.Load() <= 0 { + t.Fatal("ebc.Sent.Load() returned zero or less") + } +} + +func TestCountConnBytes(t *testing.T) { + if testing.Short() { + t.Skip("skip test in short mode") + } + + // create the trace + tx := measurexlite.NewTrace(0, time.Now()) + + // create the context + ctx := context.Background() + + // add and register session byte counter + sbc := bytecounter.New() + ctx = bytecounter.WithSessionByteCounter(ctx, sbc) + + // add and register experiment byte counter + ebc := bytecounter.New() + ctx = bytecounter.WithExperimentByteCounter(ctx, ebc) + + // create dialer + reso := tx.NewDialerWithoutResolver(log.Log) + + // run a lookup + conn, err := reso.DialContext(ctx, "tcp", "8.8.8.8:443") + defer measurexlite.MaybeClose(conn) + + // make sure we didn't fail + if err != nil { + t.Fatal(err) + } + + // create the handshaker + thx := tx.NewTLSHandshakerStdlib(log.Log) + + // handshake + tconn, err := thx.Handshake(ctx, conn, &tls.Config{ServerName: "dns.google"}) + defer measurexlite.MaybeClose(tconn) + + // make sure we didn't fail + if err != nil { + t.Fatal(err) + } + + // make sure we received something + if sbc.Received.Load() <= 0 { + t.Fatal("sbs.Received.Load() returned zero or less") + } + if ebc.Received.Load() <= 0 { + t.Fatal("ebc.Received.Load() returned zero or less") + } + + // make sure we send something + if sbc.Sent.Load() <= 0 { + t.Fatal("sbs.Sent.Load() returned zero or less") + } + if ebc.Sent.Load() <= 0 { + t.Fatal("ebc.Sent.Load() returned zero or less") + } +} diff --git a/internal/measurexlite/dialer.go b/internal/measurexlite/dialer.go index 18b37c2c33..6cd7ca1adf 100644 --- a/internal/measurexlite/dialer.go +++ b/internal/measurexlite/dialer.go @@ -12,6 +12,7 @@ import ( "strconv" "time" + "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -38,7 +39,9 @@ var _ model.Dialer = &dialerTrace{} // DialContext implements model.Dialer.DialContext. func (d *dialerTrace) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return d.d.DialContext(netxlite.ContextWithTrace(ctx, d.tx), network, address) + // Here we make sure that we're counting bytes sent and received. + dialer := bytecounter.WrapWithContextAwareDialer(d.d) + return dialer.DialContext(netxlite.ContextWithTrace(ctx, d.tx), network, address) } // CloseIdleConnections implements model.Dialer.CloseIdleConnections. diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index f450f33673..662942089d 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -12,6 +12,7 @@ import ( "time" "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/geoipx" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" @@ -93,7 +94,8 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS, // NewStdlibResolver returns a trace-ware system resolver func (tx *Trace) NewStdlibResolver(logger model.DebugLogger) model.Resolver { - return tx.wrapResolver(tx.Netx.NewStdlibResolver(logger)) + // Here we make sure that we're counting bytes sent and received. + return bytecounter.WrapWithContextAwareSystemResolver(tx.wrapResolver(tx.Netx.NewStdlibResolver(logger))) } // NewParallelUDPResolver returns a trace-ware parallel UDP resolver diff --git a/internal/measurexlite/dns_test.go b/internal/measurexlite/dns_test.go index 7b2cd48886..7596b53eb4 100644 --- a/internal/measurexlite/dns_test.go +++ b/internal/measurexlite/dns_test.go @@ -9,6 +9,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/miekg/dns" + "github.com/ooni/probe-cli/v3/internal/bytecounter" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" @@ -328,7 +329,8 @@ func TestNewWrappedResolvers(t *testing.T) { zeroTime := time.Now() trace := NewTrace(0, zeroTime) resolver := trace.NewStdlibResolver(model.DiscardLogger) - resolvert := resolver.(*resolverTrace) + resolverbyteaware := resolver.(*bytecounter.ContextAwareSystemResolver) + resolvert := resolverbyteaware.R.(*resolverTrace) if resolvert.tx != trace { t.Fatal("invalid trace") } diff --git a/internal/measurexlite/quic.go b/internal/measurexlite/quic.go index 7b089de159..03de1b5bd8 100644 --- a/internal/measurexlite/quic.go +++ b/internal/measurexlite/quic.go @@ -39,6 +39,7 @@ var _ model.QUICDialer = &quicDialerTrace{} func (qdx *quicDialerTrace) DialContext(ctx context.Context, address string, tlsConfig *tls.Config, quicConfig *quic.Config) ( quic.EarlyConnection, error) { + // TODO(https://github.com/ooni/probe/issues/2665) return qdx.qd.DialContext(netxlite.ContextWithTrace(ctx, qdx.tx), address, tlsConfig, quicConfig) }