From 39ab3d27dec8d203981e39e296695623d1c52129 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 24 Jan 2024 11:50:20 +0100 Subject: [PATCH] refactor(netemx,oohelperd): use oohelperd.NewHandler constructor (#1468) In https://github.com/ooni/probe-cli/pull/1467 we made the netemx constructor for oohelperd.Handler equivalent to oohelperd.NewHandler. So, now it becomes possible to always use oohelperd.NewHandler. While there, notice that we can make all the Handler fields private because there's no need to share them anymore, so do that. Having done this, we are now sure we have the same `oohelperd` behavior for QA and production. In turn, with this guarantee, we can write QA tests that ensure we're correctly dealing with 127.0.0.1. The reference issue is https://github.com/ooni/probe/issues/1517. --- internal/netemx/oohelperd.go | 32 ---------- internal/oohelperd/handler.go | 93 ++++++++++++++++-------------- internal/oohelperd/handler_test.go | 4 +- internal/oohelperd/measure.go | 20 +++---- 4 files changed, 61 insertions(+), 88 deletions(-) diff --git a/internal/netemx/oohelperd.go b/internal/netemx/oohelperd.go index 793a7c19b2..6bc8a8da7d 100644 --- a/internal/netemx/oohelperd.go +++ b/internal/netemx/oohelperd.go @@ -7,7 +7,6 @@ import ( "github.com/apex/log" "github.com/ooni/netem" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/oohelperd" ) @@ -26,36 +25,5 @@ func (f *OOHelperDFactory) NewHandler(env NetStackServerFactoryEnv, unet *netem. Logger: log.Log, } handler := oohelperd.NewHandler(logger, netx) - - handler.NewDialer = func(logger model.Logger) model.Dialer { - return netx.NewDialerWithResolver(logger, netx.NewStdlibResolver(logger)) - } - - handler.NewQUICDialer = func(logger model.Logger) model.QUICDialer { - return netx.NewQUICDialerWithResolver( - netx.NewUDPListener(), - logger, - netx.NewStdlibResolver(logger), - ) - } - - handler.NewResolver = func(logger model.Logger) model.Resolver { - return netx.NewStdlibResolver(logger) - } - - handler.NewHTTPClient = func(logger model.Logger) model.HTTPClient { - return oohelperd.NewHTTPClientWithTransportFactory( - netx, logger, - netxlite.NewHTTPTransportWithResolver, - ) - } - - handler.NewHTTP3Client = func(logger model.Logger) model.HTTPClient { - return oohelperd.NewHTTPClientWithTransportFactory( - netx, logger, - netxlite.NewHTTP3TransportWithResolver, - ) - } - return handler } diff --git a/internal/oohelperd/handler.go b/internal/oohelperd/handler.go index fe052f2255..dd76c42cdc 100644 --- a/internal/oohelperd/handler.go +++ b/internal/oohelperd/handler.go @@ -22,47 +22,49 @@ import ( "golang.org/x/net/publicsuffix" ) -// MaxAcceptableBodySize is the maximum acceptable body size for incoming +// maxAcceptableBodySize is the maximum acceptable body size for incoming // API requests as well as when we're measuring webpages. -const MaxAcceptableBodySize = 1 << 24 +const maxAcceptableBodySize = 1 << 24 // Handler is an [http.Handler] implementing the Web // Connectivity test helper HTTP API. +// +// The zero value is invalid; construct using [NewHandler]. type Handler struct { - // BaseLogger is the MANDATORY logger to use. - BaseLogger model.Logger + // baseLogger is the MANDATORY logger to use. + baseLogger model.Logger - // CountRequests is the MANDATORY count of the number of + // countRequests is the MANDATORY count of the number of // requests that are currently in flight. - CountRequests *atomic.Int64 + countRequests *atomic.Int64 - // Indexer is the MANDATORY atomic integer used to assign an index to requests. - Indexer *atomic.Int64 + // indexer is the MANDATORY atomic integer used to assign an index to requests. + indexer *atomic.Int64 - // MaxAcceptableBody is the MANDATORY maximum acceptable response body. - MaxAcceptableBody int64 + // maxAcceptableBody is the MANDATORY maximum acceptable response body. + maxAcceptableBody int64 - // Measure is the MANDATORY function that the handler should call + // measure is the MANDATORY function that the handler should call // for producing a response for a valid incoming request. - Measure func(ctx context.Context, config *Handler, creq *model.THRequest) (*model.THResponse, error) + measure func(ctx context.Context, config *Handler, creq *model.THRequest) (*model.THResponse, error) - // NewDialer is the MANDATORY factory to create a new Dialer. - NewDialer func(model.Logger) model.Dialer + // newDialer is the MANDATORY factory to create a new Dialer. + newDialer func(model.Logger) model.Dialer - // NewHTTPClient is the MANDATORY factory to create a new HTTPClient. - NewHTTPClient func(model.Logger) model.HTTPClient + // newHTTPClient is the MANDATORY factory to create a new HTTPClient. + newHTTPClient func(model.Logger) model.HTTPClient - // NewHTTP3Client is the MANDATORY factory to create a new HTTP3Client. - NewHTTP3Client func(model.Logger) model.HTTPClient + // newHTTP3Client is the MANDATORY factory to create a new HTTP3Client. + newHTTP3Client func(model.Logger) model.HTTPClient - // NewQUICDialer is the MANDATORY factory to create a new QUICDialer. - NewQUICDialer func(model.Logger) model.QUICDialer + // newQUICDialer is the MANDATORY factory to create a new QUICDialer. + newQUICDialer func(model.Logger) model.QUICDialer - // NewResolver is the MANDATORY factory for creating a new resolver. - NewResolver func(model.Logger) model.Resolver + // newResolver is the MANDATORY factory for creating a new resolver. + newResolver func(model.Logger) model.Resolver - // NewTLSHandshaker is the MANDATORY factory for creating a new TLS handshaker. - NewTLSHandshaker func(model.Logger) model.TLSHandshaker + // newTLSHandshaker is the MANDATORY factory for creating a new TLS handshaker. + newTLSHandshaker func(model.Logger) model.TLSHandshaker } var _ http.Handler = &Handler{} @@ -70,41 +72,44 @@ var _ http.Handler = &Handler{} // NewHandler constructs the [handler]. func NewHandler(logger model.Logger, netx *netxlite.Netx) *Handler { return &Handler{ - BaseLogger: logger, - CountRequests: &atomic.Int64{}, - Indexer: &atomic.Int64{}, - MaxAcceptableBody: MaxAcceptableBodySize, - Measure: measure, + baseLogger: logger, + countRequests: &atomic.Int64{}, + indexer: &atomic.Int64{}, + maxAcceptableBody: maxAcceptableBodySize, + measure: measure, - NewHTTPClient: func(logger model.Logger) model.HTTPClient { + newHTTPClient: func(logger model.Logger) model.HTTPClient { // TODO(https://github.com/ooni/probe/issues/2534): the NewHTTPTransportWithResolver has QUIRKS and // we should evaluate whether we can avoid using it here - return NewHTTPClientWithTransportFactory( + return newHTTPClientWithTransportFactory( netx, logger, netxlite.NewHTTPTransportWithResolver, ) }, - NewHTTP3Client: func(logger model.Logger) model.HTTPClient { - return NewHTTPClientWithTransportFactory( + newHTTP3Client: func(logger model.Logger) model.HTTPClient { + return newHTTPClientWithTransportFactory( netx, logger, netxlite.NewHTTP3TransportWithResolver, ) }, - NewDialer: func(logger model.Logger) model.Dialer { + newDialer: func(logger model.Logger) model.Dialer { return netx.NewDialerWithoutResolver(logger) }, - NewQUICDialer: func(logger model.Logger) model.QUICDialer { + + newQUICDialer: func(logger model.Logger) model.QUICDialer { return netx.NewQUICDialerWithoutResolver( netx.NewUDPListener(), logger, ) }, - NewResolver: func(logger model.Logger) model.Resolver { + + newResolver: func(logger model.Logger) model.Resolver { return newResolver(logger, netx) }, - NewTLSHandshaker: func(logger model.Logger) model.TLSHandshaker { + + newTLSHandshaker: func(logger model.Logger) model.TLSHandshaker { return netx.NewTLSHandshakerStdlib(logger) }, } @@ -151,16 +156,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { } // protect against too many requests in flight - if handlerShouldThrottleClient(h.CountRequests.Load(), req.Header.Get("user-agent")) { + if handlerShouldThrottleClient(h.countRequests.Load(), req.Header.Get("user-agent")) { metricRequestsCount.WithLabelValues("503", "service_unavailable").Inc() w.WriteHeader(503) return } - h.CountRequests.Add(1) - defer h.CountRequests.Add(-1) + h.countRequests.Add(1) + defer h.countRequests.Add(-1) // read and parse request body - reader := io.LimitReader(req.Body, h.MaxAcceptableBody) + reader := io.LimitReader(req.Body, h.maxAcceptableBody) data, err := netxlite.ReadAllContext(req.Context(), reader) if err != nil { metricRequestsCount.WithLabelValues("400", "request_body_too_large").Inc() @@ -176,7 +181,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { // measure the given input started := time.Now() - cresp, err := h.Measure(req.Context(), h, &creq) + cresp, err := h.measure(req.Context(), h, &creq) elapsed := time.Since(started) // track the time required to produce a response @@ -219,9 +224,9 @@ func newCookieJar() *cookiejar.Jar { })) } -// NewHTTPClientWithTransportFactory creates a new HTTP client +// newHTTPClientWithTransportFactory creates a new HTTP client // using the given [model.HTTPTransport] factory. -func NewHTTPClientWithTransportFactory( +func newHTTPClientWithTransportFactory( netx *netxlite.Netx, logger model.Logger, txpFactory func(*netxlite.Netx, model.DebugLogger, model.Resolver) model.HTTPTransport, ) model.HTTPClient { diff --git a/internal/oohelperd/handler_test.go b/internal/oohelperd/handler_test.go index 64578ffafd..8285e6faae 100644 --- a/internal/oohelperd/handler_test.go +++ b/internal/oohelperd/handler_test.go @@ -190,12 +190,12 @@ func TestHandlerWorkingAsIntended(t *testing.T) { // create handler and possibly override .Measure handler := NewHandler(log.Log, &netxlite.Netx{}) if expect.measureFn != nil { - handler.Measure = expect.measureFn + handler.measure = expect.measureFn } // configure the CountRequests field if needed if expect.initialCountRequests > 0 { - handler.CountRequests.Add(expect.initialCountRequests) // 0 + value = value :-) + handler.countRequests.Add(expect.initialCountRequests) // 0 + value = value :-) } // create request diff --git a/internal/oohelperd/measure.go b/internal/oohelperd/measure.go index 1d041f0e12..4ed932778a 100644 --- a/internal/oohelperd/measure.go +++ b/internal/oohelperd/measure.go @@ -28,8 +28,8 @@ type ( func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResponse, error) { // create indexed logger logger := &logx.PrefixLogger{ - Prefix: fmt.Sprintf("<#%d> ", config.Indexer.Add(1)), - Logger: config.BaseLogger, + Prefix: fmt.Sprintf("<#%d> ", config.indexer.Add(1)), + Logger: config.baseLogger, } // parse input for correctness @@ -47,7 +47,7 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp go dnsDo(ctx, &dnsConfig{ Domain: URL.Hostname(), Logger: logger, - NewResolver: config.NewResolver, + NewResolver: config.newResolver, Out: dnsch, Wg: wg, }) @@ -91,8 +91,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp EnableTLS: endpoint.TLS, Endpoint: endpoint.Epnt, Logger: logger, - NewDialer: config.NewDialer, - NewTSLHandshaker: config.NewTLSHandshaker, + NewDialer: config.newDialer, + NewTSLHandshaker: config.newTLSHandshaker, URLHostname: URL.Hostname(), Out: tcpconnch, Wg: wg, @@ -105,8 +105,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp go httpDo(ctx, &httpConfig{ Headers: creq.HTTPRequestHeaders, Logger: logger, - MaxAcceptableBody: config.MaxAcceptableBody, - NewClient: config.NewHTTPClient, + MaxAcceptableBody: config.maxAcceptableBody, + NewClient: config.newHTTPClient, Out: httpch, URL: creq.HTTPRequest, Wg: wg, @@ -133,7 +133,7 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp Address: endpoint.Addr, Endpoint: endpoint.Epnt, Logger: logger, - NewQUICDialer: config.NewQUICDialer, + NewQUICDialer: config.newQUICDialer, URLHostname: URL.Hostname(), Out: quicconnch, Wg: wg, @@ -147,8 +147,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp go httpDo(ctx, &httpConfig{ Headers: creq.HTTPRequestHeaders, Logger: logger, - MaxAcceptableBody: config.MaxAcceptableBody, - NewClient: config.NewHTTP3Client, + MaxAcceptableBody: config.maxAcceptableBody, + NewClient: config.newHTTP3Client, Out: http3ch, URL: "https://" + cresp.HTTPRequest.DiscoveredH3Endpoint, Wg: wg,