diff --git a/internal/measurexlite/dialer.go b/internal/measurexlite/dialer.go index 32b9cdb194..663b88915f 100644 --- a/internal/measurexlite/dialer.go +++ b/internal/measurexlite/dialer.go @@ -20,7 +20,7 @@ import ( // except that it returns a model.Dialer that uses this trace. func (tx *Trace) NewDialerWithoutResolver(dl model.DebugLogger) model.Dialer { return &dialerTrace{ - d: tx.newDialerWithoutResolver(dl), + d: tx.Netx.NewDialerWithoutResolver(dl), tx: tx, } } diff --git a/internal/measurexlite/dialer_test.go b/internal/measurexlite/dialer_test.go index d1f108447a..0137c91b35 100644 --- a/internal/measurexlite/dialer_test.go +++ b/internal/measurexlite/dialer_test.go @@ -21,8 +21,10 @@ func TestNewDialerWithoutResolver(t *testing.T) { underlying := &mocks.Dialer{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return underlying + }, } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) dt := dialer.(*dialerTrace) @@ -46,8 +48,10 @@ func TestNewDialerWithoutResolver(t *testing.T) { return nil, expectedErr }, } - trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return underlying + }, } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) ctx := context.Background() @@ -72,8 +76,10 @@ func TestNewDialerWithoutResolver(t *testing.T) { called = true }, } - trace.newDialerWithoutResolverFn = func(dl model.DebugLogger) model.Dialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return underlying + }, } dialer := trace.NewDialerWithoutResolver(model.DiscardLogger) dialer.CloseIdleConnections() diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index 076166565a..4d214a3d52 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -93,17 +93,17 @@ func (r *resolverTrace) LookupNS(ctx context.Context, domain string) ([]*net.NS, // NewStdlibResolver returns a trace-ware system resolver func (tx *Trace) NewStdlibResolver(logger model.Logger) model.Resolver { - return tx.wrapResolver(tx.newStdlibResolver(logger)) + return tx.wrapResolver(tx.Netx.NewStdlibResolver(logger)) } // NewParallelUDPResolver returns a trace-ware parallel UDP resolver func (tx *Trace) NewParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - return tx.wrapResolver(tx.newParallelUDPResolver(logger, dialer, address)) + return tx.wrapResolver(tx.Netx.NewParallelUDPResolver(logger, dialer, address)) } // NewParallelDNSOverHTTPSResolver returns a trace-aware parallel DoH resolver func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { - return tx.wrapResolver(tx.newParallelDNSOverHTTPSResolver(logger, URL)) + return tx.wrapResolver(tx.Netx.NewParallelDNSOverHTTPSResolver(logger, URL)) } // OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost diff --git a/internal/measurexlite/quic.go b/internal/measurexlite/quic.go index d75aed649c..fa21d05b27 100644 --- a/internal/measurexlite/quic.go +++ b/internal/measurexlite/quic.go @@ -18,7 +18,7 @@ import ( // except that it returns a model.QUICDialer that uses this trace. func (tx *Trace) NewQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { return &quicDialerTrace{ - qd: tx.newQUICDialerWithoutResolver(listener, dl), + qd: tx.Netx.NewQUICDialerWithoutResolver(listener, dl), tx: tx, } } diff --git a/internal/measurexlite/quic_test.go b/internal/measurexlite/quic_test.go index 78083643e2..24d0a6f857 100644 --- a/internal/measurexlite/quic_test.go +++ b/internal/measurexlite/quic_test.go @@ -23,8 +23,10 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { underlying := &mocks.QUICDialer{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newQUICDialerWithoutResolverFn = func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return underlying + }, } listener := &mocks.UDPListener{} dialer := trace.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) @@ -50,8 +52,10 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { return nil, expectedErr }, } - trace.newQUICDialerWithoutResolverFn = func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return underlying + }, } listener := &mocks.UDPListener{} dialer := trace.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) @@ -77,8 +81,10 @@ func TestNewQUICDialerWithoutResolver(t *testing.T) { called = true }, } - trace.newQUICDialerWithoutResolverFn = func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return underlying + }, } listener := &mocks.UDPListener{} dialer := trace.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) diff --git a/internal/measurexlite/tls.go b/internal/measurexlite/tls.go index 3c30cac3c6..be29d6cc54 100644 --- a/internal/measurexlite/tls.go +++ b/internal/measurexlite/tls.go @@ -20,7 +20,7 @@ import ( // except that it returns a model.TLSHandshaker that uses this trace. func (tx *Trace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { return &tlsHandshakerTrace{ - thx: tx.newTLSHandshakerStdlib(dl), + thx: tx.Netx.NewTLSHandshakerStdlib(dl), tx: tx, } } diff --git a/internal/measurexlite/tls_test.go b/internal/measurexlite/tls_test.go index 02a38faae8..5d98349f4d 100644 --- a/internal/measurexlite/tls_test.go +++ b/internal/measurexlite/tls_test.go @@ -24,8 +24,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { underlying := &mocks.TLSHandshaker{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewTLSHandshakerStdlib: func(logger model.DebugLogger) model.TLSHandshaker { + return underlying + }, } thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) thxt := thx.(*tlsHandshakerTrace) @@ -49,8 +51,10 @@ func TestNewTLSHandshakerStdlib(t *testing.T) { return nil, tls.ConnectionState{}, expectedErr }, } - trace.newTLSHandshakerStdlibFn = func(dl model.DebugLogger) model.TLSHandshaker { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewTLSHandshakerStdlib: func(logger model.DebugLogger) model.TLSHandshaker { + return underlying + }, } thx := trace.NewTLSHandshakerStdlib(model.DiscardLogger) ctx := context.Background() diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 2ae953125d..84bf672632 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -10,7 +10,6 @@ import ( "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" - utls "gitlab.com/yawning/utls.git" ) // Trace implements [model.Trace]. We use a [context.Context] to register ourselves @@ -31,6 +30,11 @@ type Trace struct { // once you have constructed a trace MAY lead to data races. Index int64 + // Netx is the network to use for measuring. The constructor inits this + // field using a [*netxlite.Netx]. You MAY override this field for testing. Make + // sure you do that before you start measuring to avoid data races. + Netx model.MeasuringNetwork + // bytesReceivedMap maps a remote host with the bytes we received // from such a remote host. Accessing this map requires one to // additionally hold the bytesReceivedMu mutex. @@ -40,43 +44,15 @@ type Trace struct { // access from multiple goroutines. bytesReceivedMu *sync.Mutex - // networkEvent is MANDATORY and buffers network events. - networkEvent chan *model.ArchivalNetworkEvent - - // newStdlibResolverFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewStdlibResolver factory. - newStdlibResolverFn func(logger model.Logger) model.Resolver - - // newParallelUDPResolverFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewParallelUDPResolver factory. - newParallelUDPResolverFn func(logger model.Logger, dialer model.Dialer, address string) model.Resolver - - // newParallelDNSOverHTTPSResolverFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewParallelDNSOverHTTPSUDPResolver factory. - newParallelDNSOverHTTPSResolverFn func(logger model.Logger, URL string) model.Resolver - - // newDialerWithoutResolverFn is OPTIONAL and can be used to override - // calls to the netxlite.NewDialerWithoutResolver factory. - newDialerWithoutResolverFn func(dl model.DebugLogger) model.Dialer - - // newTLSHandshakerStdlibFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewTLSHandshakerStdlib factory. - newTLSHandshakerStdlibFn func(dl model.DebugLogger) model.TLSHandshaker - - // newTLSHandshakerUTLSFn is OPTIONAL and can be used to overide - // calls to the netxlite.NewTLSHandshakerUTLS factory. - newTLSHandshakerUTLSFn func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker - - // NewDialerWithoutResolverFn is OPTIONAL and can be used to override - // calls to the netxlite.NewQUICDialerWithoutResolver factory. - newQUICDialerWithoutResolverFn func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer - // dnsLookup is MANDATORY and buffers DNS Lookup observations. dnsLookup chan *model.ArchivalDNSLookupResult // delayedDNSResponse is MANDATORY and buffers delayed DNS responses. delayedDNSResponse chan *model.ArchivalDNSLookupResult + // networkEvent is MANDATORY and buffers network events. + networkEvent chan *model.ArchivalNetworkEvent + // tcpConnect is MANDATORY and buffers TCP connect observations. tcpConnect chan *model.ArchivalTCPConnectResult @@ -134,19 +110,9 @@ const QUICHandshakeBufferSize = 8 func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { return &Trace{ Index: index, + Netx: &netxlite.Netx{Underlying: nil}, // use the host network bytesReceivedMap: make(map[string]int64), bytesReceivedMu: &sync.Mutex{}, - networkEvent: make( - chan *model.ArchivalNetworkEvent, - NetworkEventBufferSize, - ), - newStdlibResolverFn: nil, // use default - newParallelUDPResolverFn: nil, // use default - newParallelDNSOverHTTPSResolverFn: nil, // use default - newDialerWithoutResolverFn: nil, // use default - newTLSHandshakerStdlibFn: nil, // use default - newTLSHandshakerUTLSFn: nil, // use default - newQUICDialerWithoutResolverFn: nil, // use default dnsLookup: make( chan *model.ArchivalDNSLookupResult, DNSLookupBufferSize, @@ -155,6 +121,10 @@ func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { chan *model.ArchivalDNSLookupResult, DelayedDNSResponseBufferSize, ), + networkEvent: make( + chan *model.ArchivalNetworkEvent, + NetworkEventBufferSize, + ), tcpConnect: make( chan *model.ArchivalTCPConnectResult, TCPConnectBufferSize, @@ -173,69 +143,6 @@ func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { } } -// newStdlibResolver indirectly calls the passed netxlite.NewStdlibResolver -// thus allowing us to mock this function for testing -func (tx *Trace) newStdlibResolver(logger model.Logger) model.Resolver { - if tx.newStdlibResolverFn != nil { - return tx.newStdlibResolverFn(logger) - } - return netxlite.NewStdlibResolver(logger) -} - -// newParallelUDPResolver indirectly calls the passed netxlite.NewParallerUDPResolver -// thus allowing us to mock this function for testing -func (tx *Trace) newParallelUDPResolver(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - if tx.newParallelUDPResolverFn != nil { - return tx.newParallelUDPResolverFn(logger, dialer, address) - } - return netxlite.NewParallelUDPResolver(logger, dialer, address) -} - -// newParallelDNSOverHTTPSResolver indirectly calls the passed netxlite.NewParallerDNSOverHTTPSResolver -// thus allowing us to mock this function for testing -func (tx *Trace) newParallelDNSOverHTTPSResolver(logger model.Logger, URL string) model.Resolver { - if tx.newParallelDNSOverHTTPSResolverFn != nil { - return tx.newParallelDNSOverHTTPSResolverFn(logger, URL) - } - return netxlite.NewParallelDNSOverHTTPSResolver(logger, URL) -} - -// newDialerWithoutResolver indirectly calls netxlite.NewDialerWithoutResolver -// thus allowing us to mock this func for testing. -func (tx *Trace) newDialerWithoutResolver(dl model.DebugLogger) model.Dialer { - if tx.newDialerWithoutResolverFn != nil { - return tx.newDialerWithoutResolverFn(dl) - } - return netxlite.NewDialerWithoutResolver(dl) -} - -// newTLSHandshakerStdlib indirectly calls netxlite.NewTLSHandshakerStdlib -// thus allowing us to mock this func for testing. -func (tx *Trace) newTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { - if tx.newTLSHandshakerStdlibFn != nil { - return tx.newTLSHandshakerStdlibFn(dl) - } - return netxlite.NewTLSHandshakerStdlib(dl) -} - -// newTLSHandshakerUTLS indirectly calls netxlite.NewTLSHandshakerUTLS -// thus allowing us to mock this func for testing. -func (tx *Trace) newTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { - if tx.newTLSHandshakerUTLSFn != nil { - return tx.newTLSHandshakerUTLSFn(dl, id) - } - return netxlite.NewTLSHandshakerUTLS(dl, id) -} - -// newQUICDialerWithoutResolver indirectly calls netxlite.NewQUICDialerWithoutResolver -// thus allowing us to mock this func for testing. -func (tx *Trace) newQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - if tx.newQUICDialerWithoutResolverFn != nil { - return tx.newQUICDialerWithoutResolverFn(listener, dl) - } - return netxlite.NewQUICDialerWithoutResolver(listener, dl) -} - // TimeNow implements model.Trace.TimeNow. func (tx *Trace) TimeNow() time.Time { if tx.timeNowFn != nil { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index e7950f93ad..1b85d97979 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -50,45 +50,16 @@ func TestNewTrace(t *testing.T) { } }) - t.Run("NewStdlibResolverFn is nil", func(t *testing.T) { - if trace.newStdlibResolverFn != nil { - t.Fatal("expected nil NewStdlibResolverFn") + t.Run("Netx is an instance of *netxlite.Netx with a nil .Underlying", func(t *testing.T) { + if trace.Netx == nil { + t.Fatal("expected non-nil .Netx") } - }) - - t.Run("NewParallelUDPResolverFn is nil", func(t *testing.T) { - if trace.newParallelUDPResolverFn != nil { - t.Fatal("expected nil NewParallelUDPResolverFn") + netx, good := trace.Netx.(*netxlite.Netx) + if !good { + t.Fatal("not a *netxlite.Netx") } - }) - - t.Run("NewParallelDNSOverHTTPSResolverFn is nil", func(t *testing.T) { - if trace.newParallelDNSOverHTTPSResolverFn != nil { - t.Fatal("expected nil NewParallelDNSOverHTTPSResolverFn") - } - }) - - t.Run("NewDialerWithoutResolverFn is nil", func(t *testing.T) { - if trace.newDialerWithoutResolverFn != nil { - t.Fatal("expected nil NewDialerWithoutResolverFn") - } - }) - - t.Run("NewTLSHandshakerStdlibFn is nil", func(t *testing.T) { - if trace.newTLSHandshakerStdlibFn != nil { - t.Fatal("expected nil NewTLSHandshakerStdlibFn") - } - }) - - t.Run("newTLShandshakerUTLSFn is nil", func(t *testing.T) { - if trace.newTLSHandshakerUTLSFn != nil { - t.Fatal("expected nil NewTLSHandshakerUTLSfn") - } - }) - - t.Run("NewQUICDialerWithoutResolverFn is nil", func(t *testing.T) { - if trace.newQUICDialerWithoutResolverFn != nil { - t.Fatal("expected nil NewQUICDialerQithoutResolverFn") + if netx.Underlying != nil { + t.Fatal(".Underlying is not nil") } }) @@ -202,34 +173,10 @@ func TestNewTrace(t *testing.T) { } func TestTrace(t *testing.T) { - t.Run("NewStdlibResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newStdlibResolverFn: func(logger model.Logger) model.Resolver { - return &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{}, mockedErr - }, - } - }, - } - resolver := tx.newStdlibResolver(model.DiscardLogger) - ctx := context.Background() - addrs, err := resolver.LookupHost(ctx, "example.com") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) - + t.Run("NewStdlibResolver works as intended", func(t *testing.T) { t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newParallelUDPResolverFn: nil, - } - resolver := tx.newStdlibResolver(model.DiscardLogger) + tx := NewTrace(0, time.Now()) + resolver := tx.NewStdlibResolver(model.DiscardLogger) ctx, cancel := context.WithCancel(context.Background()) cancel() addrs, err := resolver.LookupHost(ctx, "example.com") @@ -242,339 +189,175 @@ func TestTrace(t *testing.T) { }) }) - t.Run("NewParallelUDPResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newParallelUDPResolverFn: func(logger model.Logger, dialer model.Dialer, address string) model.Resolver { - return &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{}, mockedErr - }, - } - }, - } - dialer := &mocks.Dialer{} - resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") - ctx := context.Background() - addrs, err := resolver.LookupHost(ctx, "example.com") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) - - t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newParallelUDPResolverFn: nil, - } - dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger) - resolver := tx.newParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") - ctx, cancel := context.WithCancel(context.Background()) - cancel() - addrs, err := resolver.LookupHost(ctx, "example.com") - if err == nil || err.Error() != netxlite.FailureInterrupted { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) + t.Run("NewParallelUDPResolver works as intended", func(t *testing.T) { + tx := NewTrace(0, time.Now()) + dialer := netxlite.NewDialerWithoutResolver(model.DiscardLogger) + resolver := tx.NewParallelUDPResolver(model.DiscardLogger, dialer, "1.1.1.1:53") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + addrs, err := resolver.LookupHost(ctx, "example.com") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") + } }) - t.Run("NewParallelDNSOverHTTPSResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newParallelDNSOverHTTPSResolverFn: func(logger model.Logger, URL string) model.Resolver { - return &mocks.Resolver{ - MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{}, mockedErr - }, - } - }, - } - resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") - ctx := context.Background() - addrs, err := resolver.LookupHost(ctx, "example.com") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) - - t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newParallelDNSOverHTTPSResolverFn: nil, - } - resolver := tx.newParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") - ctx, cancel := context.WithCancel(context.Background()) - cancel() - addrs, err := resolver.LookupHost(ctx, "example.com") - if err == nil || err.Error() != netxlite.FailureInterrupted { - t.Fatal("unexpected err", err) - } - if len(addrs) != 0 { - t.Fatal("expected array of size 0") - } - }) + t.Run("NewParallelDNSOverHTTPSResolver works as intended", func(t *testing.T) { + tx := NewTrace(0, time.Now()) + resolver := tx.NewParallelDNSOverHTTPSResolver(model.DiscardLogger, "https://dns.google.com") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + addrs, err := resolver.LookupHost(ctx, "example.com") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if len(addrs) != 0 { + t.Fatal("expected array of size 0") + } }) - t.Run("NewDialerWithoutResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newDialerWithoutResolverFn: func(dl model.DebugLogger) model.Dialer { - return &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, mockedErr - }, - } - }, - } - dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) - ctx := context.Background() - conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - tx := &Trace{ - newDialerWithoutResolverFn: nil, - } - dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) - ctx, cancel := context.WithCancel(context.Background()) - cancel() // fail immediately - conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") - if err == nil || err.Error() != netxlite.FailureInterrupted { - t.Fatal("unexpected err", err) - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewDialerWithoutResolver works as intended", func(t *testing.T) { + tx := NewTrace(0, time.Now()) + dialer := tx.NewDialerWithoutResolver(model.DiscardLogger) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // fail immediately + conn, err := dialer.DialContext(ctx, "tcp", "1.1.1.1:443") + if err == nil || err.Error() != netxlite.FailureInterrupted { + t.Fatal("unexpected err", err) + } + if conn != nil { + t.Fatal("expected nil conn") + } }) - t.Run("NewTLSHandshakerStdlibFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerStdlibFn: func(dl model.DebugLogger) model.TLSHandshaker { - return &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, mockedErr - }, - } - }, - } - thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerStdlibFn: nil, - } - thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) - tcpConn := &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil - }, - MockRemoteAddr: func() net.Addr { - return &mocks.Addr{ - MockNetwork: func() string { - return "tcp" - }, - MockString: func() string { - return "1.1.1.1:443" - }, - } - }, - MockWrite: func(b []byte) (int, error) { - return 0, mockedErr - }, - MockClose: func() error { - return nil - }, - } - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - } - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewTLSHandshakerStdlib works as intended", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := NewTrace(0, time.Now()) + thx := tx.NewTLSHandshakerStdlib(model.DiscardLogger) + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } }) - t.Run("NewTLSHandshakerUTLSFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerUTLSFn: func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { - return &mocks.TLSHandshaker{ - MockHandshake: func(ctx context.Context, conn net.Conn, config *tls.Config) (net.Conn, tls.ConnectionState, error) { - return nil, tls.ConnectionState{}, mockedErr - }, - } - }, - } - thx := tx.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, &mocks.Conn{}, &tls.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newTLSHandshakerStdlibFn: nil, - } - thx := tx.newTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) - tcpConn := &mocks.Conn{ - MockSetDeadline: func(t time.Time) error { - return nil - }, - MockRemoteAddr: func() net.Addr { - return &mocks.Addr{ - MockNetwork: func() string { - return "tcp" - }, - MockString: func() string { - return "1.1.1.1:443" - }, - } - }, - MockWrite: func(b []byte) (int, error) { - return 0, mockedErr - }, - MockClose: func() error { - return nil - }, - } - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - } - ctx := context.Background() - conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if !reflect.ValueOf(state).IsZero() { - t.Fatal("state is not a zero value") - } - if conn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewTLSHandshakerUTLS works as intended", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := NewTrace(0, time.Now()) + thx := tx.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) + tcpConn := &mocks.Conn{ + MockSetDeadline: func(t time.Time) error { + return nil + }, + MockRemoteAddr: func() net.Addr { + return &mocks.Addr{ + MockNetwork: func() string { + return "tcp" + }, + MockString: func() string { + return "1.1.1.1:443" + }, + } + }, + MockWrite: func(b []byte) (int, error) { + return 0, mockedErr + }, + MockClose: func() error { + return nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + ctx := context.Background() + conn, state, err := thx.Handshake(ctx, tcpConn, tlsConfig) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if !reflect.ValueOf(state).IsZero() { + t.Fatal("state is not a zero value") + } + if conn != nil { + t.Fatal("expected nil conn") + } }) - t.Run("NewQUICDialerWithoutResolverFn works as intended", func(t *testing.T) { - t.Run("when not nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newQUICDialerWithoutResolverFn: func(listener model.UDPListener, dl model.DebugLogger) model.QUICDialer { - return &mocks.QUICDialer{ - MockDialContext: func(ctx context.Context, address string, - tlsConfig *tls.Config, quicConfig *quic.Config) (quic.EarlyConnection, error) { - return nil, mockedErr - }, - } - }, - } - qdx := tx.newQUICDialerWithoutResolver(&mocks.UDPListener{}, model.DiscardLogger) - ctx := context.Background() - qconn, err := qdx.DialContext(ctx, "1.1.1.1:443", &tls.Config{}, &quic.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if qconn != nil { - t.Fatal("expected nil conn") - } - }) - - t.Run("when nil", func(t *testing.T) { - mockedErr := errors.New("mocked") - tx := &Trace{ - newQUICDialerWithoutResolverFn: nil, // explicit - } - pconn := &mocks.UDPLikeConn{ - MockLocalAddr: func() net.Addr { - return &net.UDPAddr{ - // quic-go does not allow the use of the same net.PacketConn for multiple "Dial" - // calls (unless a quic.Transport is used), so we have to make sure to mock local - // addresses with different ports, as tests run in parallel. - Port: 0, - } - }, - MockRemoteAddr: func() net.Addr { - return &net.UDPAddr{ - Port: 0, - } - }, - MockSyscallConn: func() (syscall.RawConn, error) { - return nil, mockedErr - }, - MockClose: func() error { - return nil - }, - MockSetReadBuffer: func(n int) error { - return nil - }, - } - listener := &mocks.UDPListener{ - MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { - return pconn, nil - }, - } - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - } - dialer := tx.newQUICDialerWithoutResolver(listener, model.DiscardLogger) - ctx := context.Background() - qconn, err := dialer.DialContext(ctx, "1.1.1.1:443", tlsConfig, &quic.Config{}) - if !errors.Is(err, mockedErr) { - t.Fatal("unexpected err", err) - } - if qconn != nil { - t.Fatal("expected nil conn") - } - }) + t.Run("NewQUICDialerWithoutResolver works as intended", func(t *testing.T) { + mockedErr := errors.New("mocked") + tx := NewTrace(0, time.Now()) + pconn := &mocks.UDPLikeConn{ + MockLocalAddr: func() net.Addr { + return &net.UDPAddr{ + // quic-go does not allow the use of the same net.PacketConn for multiple "Dial" + // calls (unless a quic.Transport is used), so we have to make sure to mock local + // addresses with different ports, as tests run in parallel. + Port: 0, + } + }, + MockRemoteAddr: func() net.Addr { + return &net.UDPAddr{ + Port: 0, + } + }, + MockSyscallConn: func() (syscall.RawConn, error) { + return nil, mockedErr + }, + MockClose: func() error { + return nil + }, + MockSetReadBuffer: func(n int) error { + return nil + }, + } + listener := &mocks.UDPListener{ + MockListen: func(addr *net.UDPAddr) (model.UDPLikeConn, error) { + return pconn, nil + }, + } + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + dialer := tx.NewQUICDialerWithoutResolver(listener, model.DiscardLogger) + ctx := context.Background() + qconn, err := dialer.DialContext(ctx, "1.1.1.1:443", tlsConfig, &quic.Config{}) + if !errors.Is(err, mockedErr) { + t.Fatal("unexpected err", err) + } + if qconn != nil { + t.Fatal("expected nil conn") + } }) t.Run("TimeNowFn works as intended", func(t *testing.T) { diff --git a/internal/measurexlite/utls.go b/internal/measurexlite/utls.go index 747c097972..d00016006b 100644 --- a/internal/measurexlite/utls.go +++ b/internal/measurexlite/utls.go @@ -13,7 +13,7 @@ import ( // except that it returns a model.TLSHandshaker that uses this trace. func (tx *Trace) NewTLSHandshakerUTLS(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { return &tlsHandshakerTrace{ - thx: tx.newTLSHandshakerUTLS(dl, id), + thx: tx.Netx.NewTLSHandshakerUTLS(dl, id), tx: tx, } } diff --git a/internal/measurexlite/utls_test.go b/internal/measurexlite/utls_test.go index 2dccd89f0a..959aff0657 100644 --- a/internal/measurexlite/utls_test.go +++ b/internal/measurexlite/utls_test.go @@ -14,8 +14,10 @@ func TestNewTLSHandshakerUTLS(t *testing.T) { underlying := &mocks.TLSHandshaker{} zeroTime := time.Now() trace := NewTrace(0, zeroTime) - trace.newTLSHandshakerUTLSFn = func(dl model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { - return underlying + trace.Netx = &mocks.MeasuringNetwork{ + MockNewTLSHandshakerUTLS: func(logger model.DebugLogger, id *utls.ClientHelloID) model.TLSHandshaker { + return underlying + }, } thx := trace.NewTLSHandshakerUTLS(model.DiscardLogger, &utls.HelloGolang) thxt := thx.(*tlsHandshakerTrace)