From d3274c62eaf9efa2f3ecb84186419627bbb53e93 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Tue, 12 Sep 2023 15:46:41 +0200 Subject: [PATCH] refactor(measurexlite): depend on model.MeasuringNetwork (#1260) With this diff, we detach measurexlite from netxlite. It was already quite detached through functions used for testing. However, the changes we implement are allow us to test measurexlite code by changing the .Netx field of a *Trace, which means we can avoid using netxlite's singleton for new code. Another benefit of this diff is that we have clearly spelled out and packaged into an interface the dependencies required to perform measurements with measurexlite. Therefore, we can gracefully continue separating the code used for measuring from the code for contacting the backend, as detailed in https://github.com/ooni/probe/issues/2531. --- internal/measurexlite/dialer.go | 2 +- internal/measurexlite/dialer_test.go | 18 +- internal/measurexlite/dns.go | 6 +- internal/measurexlite/quic.go | 2 +- internal/measurexlite/quic_test.go | 18 +- internal/measurexlite/tls.go | 2 +- internal/measurexlite/tls_test.go | 12 +- internal/measurexlite/trace.go | 119 +----- internal/measurexlite/trace_test.go | 557 ++++++++------------------- internal/measurexlite/utls.go | 2 +- internal/measurexlite/utls_test.go | 6 +- 11 files changed, 226 insertions(+), 518 deletions(-) diff --git a/internal/measurexlite/dialer.go b/internal/measurexlite/dialer.go index 32b9cdb194..663b88915f 100644 --- a/internal/measurexlite/dialer.go +++ b/internal/measurexlite/dialer.go @@ -20,7 +20,7 @@ import ( // except that it returns a model.Dialer that uses this trace. func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer { return &dialerTrace{ - d: tx.newDialerWithoutResolver(dl), + d: tx.Netx.NewDialerWithoutResolver(dl), tx: tx, } } diff --git a/internal/measurexlite/dialer_test.go b/internal/measurexlite/dialer_test.go index d1f108447a..0137c91b35 100644 --- a/internal/measurexlite/dialer_test.go +++ b/internal/measurexlite/dialer_test.go @@ -21,8 +21,10 @@ func TestNewDialerWithoutResolver(t *testing.T) { underlying := &mocks.Dialer{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return underlying + }, } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) dt := dialer.(*dialerTrace) @@ -46,8 +48,10 @@ func TestNewDialerWithoutResolver(t *testing.T) { return nil, expectedErr }, } - trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return underlying + }, } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) ctx := context.Background() @@ -72,8 +76,10 @@ func TestNewDialerWithoutResolver(t *testing.T) { called = true }, } - trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return underlying + }, } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) dialer.CloseIdleConnections() diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index 076166565a..4d214a3d52 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -93,17 +93,17 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS, // NewStdlibResolver returns a trace-ware system resolver func (tx *Trace) NewStdlibResolver(logger model.Logger) model.Resolver { - return tx.wrapResolver(tx.newStdlibResolver(logger)) + return tx.wrapResolver(tx.Netx.NewStdlibResolver(logger)) } // NewParallelUDPResolver returns a trace-ware parallel UDP resolver func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - return tx.wrapResolver(tx.newParallelUDPResolver(logger, dialer, address)) + return tx.wrapResolver(tx.Netx.NewParallelUDPResolver(logger, dialer, address)) } // NewParallelDNSOverHTTPSResolver returns a trace-aware parallel DoH resolver func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { - return tx.wrapResolver(tx.newParallelDNSOverHTTPSResolver(logger, URL)) + return tx.wrapResolver(tx.Netx.NewParallelDNSOverHTTPSResolver(logger, URL)) } // OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost diff --git a/internal/measurexlite/quic.go b/internal/measurexlite/quic.go index d75aed649c..fa21d05b27 100644 --- a/internal/measurexlite/quic.go +++ b/internal/measurexlite/quic.go @@ -18,7 +18,7 @@ import ( // except that it returns a model.QUICDialer that uses this trace. func (tx *Trace) NewQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { return &quicDialerTrace{ - qd: tx.newQUICDialerWithoutResolver(listener, dl), + qd: tx.Netx.NewQUICDialerWithoutResolver(listener, dl), tx: tx, } } diff --git a/internal/measurexlite/quic_test.go b/internal/measurexlite/quic_test.go index 78083643e2..24d0a6f857 100644 --- a/internal/measurexlite/quic_test.go +++ b/internal/measurexlite/quic_test.go @@ -23,8 +23,10 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { underlying := &mocks.QUICDialer{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newQUICDialerWithoutResolverFn = func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return underlying + }, } listener := &mocks.UDPListener{} dialer := trace.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) @@ -50,8 +52,10 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { return nil, expectedErr }, } - trace.newQUICDialerWithoutResolverFn = func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return underlying + }, } listener := &mocks.UDPListener{} dialer := trace.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) @@ -77,8 +81,10 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { called = true }, } - trace.newQUICDialerWithoutResolverFn = func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return underlying + }, } listener := &mocks.UDPListener{} dialer := trace.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) diff --git a/internal/measurexlite/tls.go b/internal/measurexlite/tls.go index 3c30cac3c6..be29d6cc54 100644 --- a/internal/measurexlite/tls.go +++ b/internal/measurexlite/tls.go @@ -20,7 +20,7 @@ import ( // except that it returns a model.TLSHandshaker that uses this trace. func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { return &tlsHandshakerTrace{ - thx: tx.newTLSHandshakerStdlib(dl), + thx: tx.Netx.NewTLSHandshakerStdlib(dl), tx: tx, } } diff --git a/internal/measurexlite/tls_test.go b/internal/measurexlite/tls_test.go index 02a38faae8..5d98349f4d 100644 --- a/internal/measurexlite/tls_test.go +++ b/internal/measurexlite/tls_test.go @@ -24,8 +24,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { underlying := &mocks.TLSHandshaker{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewTLSHandshakerStdlib: func(logger model.DebugLogger) model.TLSHandshaker { + return underlying + }, } thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) thxt := thx.(*tlsHandshakerTrace) @@ -49,8 +51,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { return nil, tls.ConnectionState{}, expectedErr }, } - trace.newTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewTLSHandshakerStdlib: func(logger model.DebugLogger) model.TLSHandshaker { + return underlying + }, } thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) ctx := context.Background() diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 2ae953125d..84bf672632 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -10,7 +10,6 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" - utls "gitlab.com/yawning/utls.git" ) // Trace implements [model.Trace]. We use a [context.Context] to register ourselves @@ -31,6 +30,11 @@ type Trace struct { // once you have constructed a trace MAY lead to data races. Index int64 + // Netx is the network to use for measuring. The constructor inits this + // field using a [*netxlite.Netx]. You MAY override this field for testing. Make + // sure you do that before you start measuring to avoid data races. + Netx model.MeasuringNetwork + // bytesReceivedMap maps a remote host with the bytes we received // from such a remote host. Accessing this map requires one to // additionally hold the bytesReceivedMu mutex. @@ -40,43 +44,15 @@ type Trace struct { // access from multiple goroutines. bytesReceivedMu *sync.Mutex - // networkEvent is MANDATORY and buffers network events. - networkEvent chan *model.ArchivalNetworkEvent - - // newStdlibResolverFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewStdlibResolver factory. - newStdlibResolverFn func(logger model.Logger) model.Resolver - - // newParallelUDPResolverFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewParallelUDPResolver factory. - newParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver - - // newParallelDNSOverHTTPSResolverFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewParallelDNSOverHTTPSUDPResolver factory. - newParallelDNSOverHTTPSResolverFn func(logger model.Logger, URL string) model.Resolver - - // newDialerWithoutResolverFn is OPTIONAL and can be used to override - // calls to the netxlite.NewDialerWithoutResolver factory. - newDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer - - // newTLSHandshakerStdlibFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewTLSHandshakerStdlib factory. - newTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker - - // newTLSHandshakerUTLSFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewTLSHandshakerUTLS factory. - newTLSHandshakerUTLSFn func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker - - // NewDialerWithoutResolverFn is OPTIONAL and can be used to override - // calls to the netxlite.NewQUICDialerWithoutResolver factory. - newQUICDialerWithoutResolverFn func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer - // dnsLookup is MANDATORY and buffers DNS Lookup observations. dnsLookup chan *model.ArchivalDNSLookupResult // delayedDNSResponse is MANDATORY and buffers delayed DNS responses. delayedDNSResponse chan *model.ArchivalDNSLookupResult + // networkEvent is MANDATORY and buffers network events. + networkEvent chan *model.ArchivalNetworkEvent + // tcpConnect is MANDATORY and buffers TCP connect observations. tcpConnect chan *model.ArchivalTCPConnectResult @@ -134,19 +110,9 @@ const QUICHandshakeBufferSize = 8 func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { return &Trace{ Index: index, + Netx: &netxlite.Netx{Underlying: nil}, // use the host network bytesReceivedMap: make(map[string]int64), bytesReceivedMu: &sync.Mutex{}, - networkEvent: make( - chan *model.ArchivalNetworkEvent, - NetworkEventBufferSize, - ), - newStdlibResolverFn: nil, // use default - newParallelUDPResolverFn: nil, // use default - newParallelDNSOverHTTPSResolverFn: nil, // use default - newDialerWithoutResolverFn: nil, // use default - newTLSHandshakerStdlibFn: nil, // use default - newTLSHandshakerUTLSFn: nil, // use default - newQUICDialerWithoutResolverFn: nil, // use default dnsLookup: make( chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize, @@ -155,6 +121,10 @@ func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { chan *model.ArchivalDNSLookupResult, DelayedDNSResponseBufferSize, ), + networkEvent: make( + chan *model.ArchivalNetworkEvent, + NetworkEventBufferSize, + ), tcpConnect: make( chan *model.ArchivalTCPConnectResult, TCPConnectBufferSize, @@ -173,69 +143,6 @@ func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { } } -// newStdlibResolver indirectly calls the passed netxlite.NewStdlibResolver -// thus allowing us to mock this function for testing -func (tx *Trace) newStdlibResolver(logger model.Logger) model.Resolver { - if tx.newStdlibResolverFn != nil { - return tx.newStdlibResolverFn(logger) - } - return netxlite.NewStdlibResolver(logger) -} - -// newParallelUDPResolver indirectly calls the passed netxlite.NewParallerUDPResolver -// thus allowing us to mock this function for testing -func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - if tx.newParallelUDPResolverFn != nil { - return tx.newParallelUDPResolverFn(logger, dialer, address) - } - return netxlite.NewParallelUDPResolver(logger, dialer, address) -} - -// newParallelDNSOverHTTPSResolver indirectly calls the passed netxlite.NewParallerDNSOverHTTPSResolver -// thus allowing us to mock this function for testing -func (tx *Trace) newParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { - if tx.newParallelDNSOverHTTPSResolverFn != nil { - return tx.newParallelDNSOverHTTPSResolverFn(logger, URL) - } - return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL) -} - -// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver -// thus allowing us to mock this func for testing. -func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { - if tx.newDialerWithoutResolverFn != nil { - return tx.newDialerWithoutResolverFn(dl) - } - return netxlite.NewDialerWithoutResolver(dl) -} - -// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib -// thus allowing us to mock this func for testing. -func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { - if tx.newTLSHandshakerStdlibFn != nil { - return tx.newTLSHandshakerStdlibFn(dl) - } - return netxlite.NewTLSHandshakerStdlib(dl) -} - -// newTLSHandshakerUTLS indirectly calls netxlite.NewTLSHandshakerUTLS -// thus allowing us to mock this func for testing. -func (tx *Trace) newTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { - if tx.newTLSHandshakerUTLSFn != nil { - return tx.newTLSHandshakerUTLSFn(dl, id) - } - return netxlite.NewTLSHandshakerUTLS(dl, id) -} - -// newQUICDialerWithoutResolver indirectly calls netxlite.NewQUICDialerWithoutResolver -// thus allowing us to mock this func for testing. -func (tx *Trace) newQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - if tx.newQUICDialerWithoutResolverFn != nil { - return tx.newQUICDialerWithoutResolverFn(listener, dl) - } - return netxlite.NewQUICDialerWithoutResolver(listener, dl) -} - // TimeNow implements model.Trace.TimeNow. func (tx *Trace) TimeNow() time.Time { if tx.timeNowFn != nil { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index e7950f93ad..1b85d97979 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -50,45 +50,16 @@ func TestNewTrace(t *testing.T) { } }) - t.Run("NewStdlibResolverFn is nil", func(t *testing.T) { - if trace.newStdlibResolverFn != nil { - t.Fatal("expected nil NewStdlibResolverFn") + t.Run("Netx is an instance of *netxlite.Netx with a nil .Underlying", func(t *testing.T) { + if trace.Netx == nil { + t.Fatal("expected non-nil .Netx") } - }) - - t.Run("NewParallelUDPResolverFn is nil", func(t *testing.T) { - if trace.newParallelUDPResolverFn != nil { - t.Fatal("expected nil NewParallelUDPResolverFn") + netx, good := trace.Netx.(*netxlite.Netx) + if !good { + t.Fatal("not a *netxlite.Netx") } - }) - - t.Run("NewParallelDNSOverHTTPSResolverFn is nil", func(t *testing.T) { - if trace.newParallelDNSOverHTTPSResolverFn != nil { - t.Fatal("expected nil NewParallelDNSOverHTTPSResolverFn") - } - }) - - t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) { - if trace.newDialerWithoutResolverFn != nil { - t.Fatal("expected nil NewDialerWithoutResolverFn") - } - }) - - t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) { - if trace.newTLSHandshakerStdlibFn != nil { - t.Fatal("expected nil NewTLSHandshakerStdlibFn") - } - }) - - t.Run("newTLShandshakerUTLSFn is nil", func(t *testing.T) { - if trace.newTLSHandshakerUTLSFn != nil { - t.Fatal("expected nil NewTLSHandshakerUTLSfn") - } - }) - - t.Run("NewQUICDialerWithoutResolverFn is nil", func(t *testing.T) { - if trace.newQUICDialerWithoutResolverFn != nil { - t.Fatal("expected nil NewQUICDialerQithoutResolverFn") + if netx.Underlying != nil { + t.Fatal(".Underlying is not nil") } }) @@ -202,34 +173,10 @@ func TestNewTrace(t *testing.T) { } func TestTrace(t *testing.T) { - t.Run("NewStdlibResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newStdlibResolverFn: func(logger model.Logger) model.Resolver { - return &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{}, mockedErr - }, - } - }, - } - resolver := tx.newStdlibResolver(model.DiscardLogger) - ctx := context.Background() - addrs, err := resolver.LookupHost(ctx, "example.com") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) - + t.Run("NewStdlibResolver works as intended", func(t *testing.T) { t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newParallelUDPResolverFn: nil, - } - resolver := tx.newStdlibResolver(model.DiscardLogger) + tx := NewTrace(0, time.Now()) + resolver := tx.NewStdlibResolver(model.DiscardLogger) ctx, cancel := context.WithCancel(context.Background()) cancel() addrs, err := resolver.LookupHost(ctx, "example.com") @@ -242,339 +189,175 @@ func TestTrace(t *testing.T) { }) }) - t.Run("NewParallelUDPResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newParallelUDPResolverFn: func(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - return &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{}, mockedErr - }, - } - }, - } - dialer := &mocks.Dialer{} - resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") - ctx := context.Background() - addrs, err := resolver.LookupHost(ctx, "example.com") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) - - t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newParallelUDPResolverFn: nil, - } - dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger) - resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") - ctx, cancel := context.WithCancel(context.Background()) - cancel() - addrs, err := resolver.LookupHost(ctx, "example.com") - if err == nil || err.Error() != netxlite.FailureInterrupted { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) + t.Run("NewParallelUDPResolver works as intended", func(t *testing.T) { + tx := NewTrace(0, time.Now()) + dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger) + resolver := tx.NewParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + addrs, err := resolver.LookupHost(ctx, "example.com") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") + } }) - t.Run("NewParallelDNSOverHTTPSResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newParallelDNSOverHTTPSResolverFn: func(logger model.Logger, URL string) model.Resolver { - return &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{}, mockedErr - }, - } - }, - } - resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") - ctx := context.Background() - addrs, err := resolver.LookupHost(ctx, "example.com") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) - - t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newParallelDNSOverHTTPSResolverFn: nil, - } - resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") - ctx, cancel := context.WithCancel(context.Background()) - cancel() - addrs, err := resolver.LookupHost(ctx, "example.com") - if err == nil || err.Error() != netxlite.FailureInterrupted { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) + t.Run("NewParallelDNSOverHTTPSResolver works as intended", func(t *testing.T) { + tx := NewTrace(0, time.Now()) + resolver := tx.NewParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + addrs, err := resolver.LookupHost(ctx, "example.com") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") + } }) - t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer { - return &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, mockedErr - }, - } - }, - } - dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) - ctx := context.Background() - conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newDialerWithoutResolverFn: nil, - } - dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) - ctx, cancel := context.WithCancel(context.Background()) - cancel() // fail immediately - conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") - if err == nil || err.Error() != netxlite.FailureInterrupted { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewDialerWithoutResolver works as intended", func(t *testing.T) { + tx := NewTrace(0, time.Now()) + dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } }) - t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker { - return &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, mockedErr - }, - } - }, - } - thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerStdlibFn: nil, - } - thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) - tcpConn := &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil - }, - MockRemoteAddr: func() net.Addr { - return &mocks.Addr{ - MockNetwork: func() string { - return "tcp" - }, - MockString: func() string { - return "1.1.1.1:443" - }, - } - }, - MockWrite: func(b []byte) (int, error) { - return 0, mockedErr - }, - MockClose: func() error { - return nil - }, - } - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - } - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewTLSHandshakerStdlib works as intended", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := NewTrace(0, time.Now()) + thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } }) - t.Run("NewTLSHandshakerUTLSFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerUTLSFn: func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { - return &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, mockedErr - }, - } - }, - } - thx := tx.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerStdlibFn: nil, - } - thx := tx.newTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) - tcpConn := &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil - }, - MockRemoteAddr: func() net.Addr { - return &mocks.Addr{ - MockNetwork: func() string { - return "tcp" - }, - MockString: func() string { - return "1.1.1.1:443" - }, - } - }, - MockWrite: func(b []byte) (int, error) { - return 0, mockedErr - }, - MockClose: func() error { - return nil - }, - } - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - } - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewTLSHandshakerUTLS works as intended", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := NewTrace(0, time.Now()) + thx := tx.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } }) - t.Run("NewQUICDialerWithoutResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newQUICDialerWithoutResolverFn: func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return &mocks.QUICDialer{ - MockDialContext: func(ctx context.Context, address string, - tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { - return nil, mockedErr - }, - } - }, - } - qdx := tx.newQUICDialerWithoutResolver(&mocks.UDPListener{}, model.DiscardLogger) - ctx := context.Background() - qconn, err := qdx.DialContext(ctx, "1.1.1.1:443", &tls.Config{}, &quic.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if qconn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newQUICDialerWithoutResolverFn: nil, // explicit - } - pconn := &mocks.UDPLikeConn{ - MockLocalAddr: func() net.Addr { - return &net.UDPAddr{ - // quic-go does not allow the use of the same net.PacketConn for multiple "Dial" - // calls (unless a quic.Transport is used), so we have to make sure to mock local - // addresses with different ports, as tests run in parallel. - Port: 0, - } - }, - MockRemoteAddr: func() net.Addr { - return &net.UDPAddr{ - Port: 0, - } - }, - MockSyscallConn: func() (syscall.RawConn, error) { - return nil, mockedErr - }, - MockClose: func() error { - return nil - }, - MockSetReadBuffer: func(n int) error { - return nil - }, - } - listener := &mocks.UDPListener{ - MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { - return pconn, nil - }, - } - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - } - dialer := tx.newQUICDialerWithoutResolver(listener, model.DiscardLogger) - ctx := context.Background() - qconn, err := dialer.DialContext(ctx, "1.1.1.1:443", tlsConfig, &quic.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if qconn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewQUICDialerWithoutResolver works as intended", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := NewTrace(0, time.Now()) + pconn := &mocks.UDPLikeConn{ + MockLocalAddr: func() net.Addr { + return &net.UDPAddr{ + // quic-go does not allow the use of the same net.PacketConn for multiple "Dial" + // calls (unless a quic.Transport is used), so we have to make sure to mock local + // addresses with different ports, as tests run in parallel. + Port: 0, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.UDPAddr{ + Port: 0, + } + }, + MockSyscallConn: func() (syscall.RawConn, error) { + return nil, mockedErr + }, + MockClose: func() error { + return nil + }, + MockSetReadBuffer: func(n int) error { + return nil + }, + } + listener := &mocks.UDPListener{ + MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { + return pconn, nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + dialer := tx.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) + ctx := context.Background() + qconn, err := dialer.DialContext(ctx, "1.1.1.1:443", tlsConfig, &quic.Config{}) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if qconn != nil { + t.Fatal("expected nil conn") + } }) t.Run("TimeNowFn works as intended", func(t *testing.T) { diff --git a/internal/measurexlite/utls.go b/internal/measurexlite/utls.go index 747c097972..d00016006b 100644 --- a/internal/measurexlite/utls.go +++ b/internal/measurexlite/utls.go @@ -13,7 +13,7 @@ import ( // except that it returns a model.TLSHandshaker that uses this trace. func (tx *Trace) NewTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { return &tlsHandshakerTrace{ - thx: tx.newTLSHandshakerUTLS(dl, id), + thx: tx.Netx.NewTLSHandshakerUTLS(dl, id), tx: tx, } } diff --git a/internal/measurexlite/utls_test.go b/internal/measurexlite/utls_test.go index 2dccd89f0a..959aff0657 100644 --- a/internal/measurexlite/utls_test.go +++ b/internal/measurexlite/utls_test.go @@ -14,8 +14,10 @@ func TestNewTLSHandshakerUTLS(t *testing.T) { underlying := &mocks.TLSHandshaker{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newTLSHandshakerUTLSFn = func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewTLSHandshakerUTLS: func(logger model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + return underlying + }, } thx := trace.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) thxt := thx.(*tlsHandshakerTrace)