Skip to content

Commit

Permalink
refactor(netemx,oohelperd): use oohelperd.NewHandler constructor (#1468)
Browse files Browse the repository at this point in the history
In #1467 we made the netemx
constructor for oohelperd.Handler equivalent to oohelperd.NewHandler.

So, now it becomes possible to always use oohelperd.NewHandler.

While there, notice that we can make all the Handler fields private
because there's no need to share them anymore, so do that.

Having done this, we are now sure we have the same `oohelperd` behavior
for QA and production.

In turn, with this guarantee, we can write QA tests that ensure we're
correctly dealing with 127.0.0.1.

The reference issue is ooni/probe#1517.
  • Loading branch information
bassosimone authored Jan 24, 2024
1 parent 8331a30 commit 39ab3d2
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 88 deletions.
32 changes: 0 additions & 32 deletions internal/netemx/oohelperd.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"github.com/apex/log"
"github.com/ooni/netem"
"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/ooni/probe-cli/v3/internal/oohelperd"
)
Expand All @@ -26,36 +25,5 @@ func (f *OOHelperDFactory) NewHandler(env NetStackServerFactoryEnv, unet *netem.
Logger: log.Log,
}
handler := oohelperd.NewHandler(logger, netx)

handler.NewDialer = func(logger model.Logger) model.Dialer {
return netx.NewDialerWithResolver(logger, netx.NewStdlibResolver(logger))
}

handler.NewQUICDialer = func(logger model.Logger) model.QUICDialer {
return netx.NewQUICDialerWithResolver(
netx.NewUDPListener(),
logger,
netx.NewStdlibResolver(logger),
)
}

handler.NewResolver = func(logger model.Logger) model.Resolver {
return netx.NewStdlibResolver(logger)
}

handler.NewHTTPClient = func(logger model.Logger) model.HTTPClient {
return oohelperd.NewHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTPTransportWithResolver,
)
}

handler.NewHTTP3Client = func(logger model.Logger) model.HTTPClient {
return oohelperd.NewHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTP3TransportWithResolver,
)
}

return handler
}
93 changes: 49 additions & 44 deletions internal/oohelperd/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,89 +22,94 @@ import (
"golang.org/x/net/publicsuffix"
)

// MaxAcceptableBodySize is the maximum acceptable body size for incoming
// maxAcceptableBodySize is the maximum acceptable body size for incoming
// API requests as well as when we're measuring webpages.
const MaxAcceptableBodySize = 1 << 24
const maxAcceptableBodySize = 1 << 24

// Handler is an [http.Handler] implementing the Web
// Connectivity test helper HTTP API.
//
// The zero value is invalid; construct using [NewHandler].
type Handler struct {
// BaseLogger is the MANDATORY logger to use.
BaseLogger model.Logger
// baseLogger is the MANDATORY logger to use.
baseLogger model.Logger

// CountRequests is the MANDATORY count of the number of
// countRequests is the MANDATORY count of the number of
// requests that are currently in flight.
CountRequests *atomic.Int64
countRequests *atomic.Int64

// Indexer is the MANDATORY atomic integer used to assign an index to requests.
Indexer *atomic.Int64
// indexer is the MANDATORY atomic integer used to assign an index to requests.
indexer *atomic.Int64

// MaxAcceptableBody is the MANDATORY maximum acceptable response body.
MaxAcceptableBody int64
// maxAcceptableBody is the MANDATORY maximum acceptable response body.
maxAcceptableBody int64

// Measure is the MANDATORY function that the handler should call
// measure is the MANDATORY function that the handler should call
// for producing a response for a valid incoming request.
Measure func(ctx context.Context, config *Handler, creq *model.THRequest) (*model.THResponse, error)
measure func(ctx context.Context, config *Handler, creq *model.THRequest) (*model.THResponse, error)

// NewDialer is the MANDATORY factory to create a new Dialer.
NewDialer func(model.Logger) model.Dialer
// newDialer is the MANDATORY factory to create a new Dialer.
newDialer func(model.Logger) model.Dialer

// NewHTTPClient is the MANDATORY factory to create a new HTTPClient.
NewHTTPClient func(model.Logger) model.HTTPClient
// newHTTPClient is the MANDATORY factory to create a new HTTPClient.
newHTTPClient func(model.Logger) model.HTTPClient

// NewHTTP3Client is the MANDATORY factory to create a new HTTP3Client.
NewHTTP3Client func(model.Logger) model.HTTPClient
// newHTTP3Client is the MANDATORY factory to create a new HTTP3Client.
newHTTP3Client func(model.Logger) model.HTTPClient

// NewQUICDialer is the MANDATORY factory to create a new QUICDialer.
NewQUICDialer func(model.Logger) model.QUICDialer
// newQUICDialer is the MANDATORY factory to create a new QUICDialer.
newQUICDialer func(model.Logger) model.QUICDialer

// NewResolver is the MANDATORY factory for creating a new resolver.
NewResolver func(model.Logger) model.Resolver
// newResolver is the MANDATORY factory for creating a new resolver.
newResolver func(model.Logger) model.Resolver

// NewTLSHandshaker is the MANDATORY factory for creating a new TLS handshaker.
NewTLSHandshaker func(model.Logger) model.TLSHandshaker
// newTLSHandshaker is the MANDATORY factory for creating a new TLS handshaker.
newTLSHandshaker func(model.Logger) model.TLSHandshaker
}

var _ http.Handler = &Handler{}

// NewHandler constructs the [handler].
func NewHandler(logger model.Logger, netx *netxlite.Netx) *Handler {
return &Handler{
BaseLogger: logger,
CountRequests: &atomic.Int64{},
Indexer: &atomic.Int64{},
MaxAcceptableBody: MaxAcceptableBodySize,
Measure: measure,
baseLogger: logger,
countRequests: &atomic.Int64{},
indexer: &atomic.Int64{},
maxAcceptableBody: maxAcceptableBodySize,
measure: measure,

NewHTTPClient: func(logger model.Logger) model.HTTPClient {
newHTTPClient: func(logger model.Logger) model.HTTPClient {
// TODO(https://github.com/ooni/probe/issues/2534): the NewHTTPTransportWithResolver has QUIRKS and
// we should evaluate whether we can avoid using it here
return NewHTTPClientWithTransportFactory(
return newHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTPTransportWithResolver,
)
},

NewHTTP3Client: func(logger model.Logger) model.HTTPClient {
return NewHTTPClientWithTransportFactory(
newHTTP3Client: func(logger model.Logger) model.HTTPClient {
return newHTTPClientWithTransportFactory(
netx, logger,
netxlite.NewHTTP3TransportWithResolver,
)
},

NewDialer: func(logger model.Logger) model.Dialer {
newDialer: func(logger model.Logger) model.Dialer {
return netx.NewDialerWithoutResolver(logger)
},
NewQUICDialer: func(logger model.Logger) model.QUICDialer {

newQUICDialer: func(logger model.Logger) model.QUICDialer {
return netx.NewQUICDialerWithoutResolver(
netx.NewUDPListener(),
logger,
)
},
NewResolver: func(logger model.Logger) model.Resolver {

newResolver: func(logger model.Logger) model.Resolver {
return newResolver(logger, netx)
},
NewTLSHandshaker: func(logger model.Logger) model.TLSHandshaker {

newTLSHandshaker: func(logger model.Logger) model.TLSHandshaker {
return netx.NewTLSHandshakerStdlib(logger)
},
}
Expand Down Expand Up @@ -151,16 +156,16 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}

// protect against too many requests in flight
if handlerShouldThrottleClient(h.CountRequests.Load(), req.Header.Get("user-agent")) {
if handlerShouldThrottleClient(h.countRequests.Load(), req.Header.Get("user-agent")) {
metricRequestsCount.WithLabelValues("503", "service_unavailable").Inc()
w.WriteHeader(503)
return
}
h.CountRequests.Add(1)
defer h.CountRequests.Add(-1)
h.countRequests.Add(1)
defer h.countRequests.Add(-1)

// read and parse request body
reader := io.LimitReader(req.Body, h.MaxAcceptableBody)
reader := io.LimitReader(req.Body, h.maxAcceptableBody)
data, err := netxlite.ReadAllContext(req.Context(), reader)
if err != nil {
metricRequestsCount.WithLabelValues("400", "request_body_too_large").Inc()
Expand All @@ -176,7 +181,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {

// measure the given input
started := time.Now()
cresp, err := h.Measure(req.Context(), h, &creq)
cresp, err := h.measure(req.Context(), h, &creq)
elapsed := time.Since(started)

// track the time required to produce a response
Expand Down Expand Up @@ -219,9 +224,9 @@ func newCookieJar() *cookiejar.Jar {
}))
}

// NewHTTPClientWithTransportFactory creates a new HTTP client
// newHTTPClientWithTransportFactory creates a new HTTP client
// using the given [model.HTTPTransport] factory.
func NewHTTPClientWithTransportFactory(
func newHTTPClientWithTransportFactory(
netx *netxlite.Netx, logger model.Logger,
txpFactory func(*netxlite.Netx, model.DebugLogger, model.Resolver) model.HTTPTransport,
) model.HTTPClient {
Expand Down
4 changes: 2 additions & 2 deletions internal/oohelperd/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ func TestHandlerWorkingAsIntended(t *testing.T) {
// create handler and possibly override .Measure
handler := NewHandler(log.Log, &netxlite.Netx{})
if expect.measureFn != nil {
handler.Measure = expect.measureFn
handler.measure = expect.measureFn
}

// configure the CountRequests field if needed
if expect.initialCountRequests > 0 {
handler.CountRequests.Add(expect.initialCountRequests) // 0 + value = value :-)
handler.countRequests.Add(expect.initialCountRequests) // 0 + value = value :-)
}

// create request
Expand Down
20 changes: 10 additions & 10 deletions internal/oohelperd/measure.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ type (
func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResponse, error) {
// create indexed logger
logger := &logx.PrefixLogger{
Prefix: fmt.Sprintf("<#%d> ", config.Indexer.Add(1)),
Logger: config.BaseLogger,
Prefix: fmt.Sprintf("<#%d> ", config.indexer.Add(1)),
Logger: config.baseLogger,
}

// parse input for correctness
Expand All @@ -47,7 +47,7 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
go dnsDo(ctx, &dnsConfig{
Domain: URL.Hostname(),
Logger: logger,
NewResolver: config.NewResolver,
NewResolver: config.newResolver,
Out: dnsch,
Wg: wg,
})
Expand Down Expand Up @@ -91,8 +91,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
EnableTLS: endpoint.TLS,
Endpoint: endpoint.Epnt,
Logger: logger,
NewDialer: config.NewDialer,
NewTSLHandshaker: config.NewTLSHandshaker,
NewDialer: config.newDialer,
NewTSLHandshaker: config.newTLSHandshaker,
URLHostname: URL.Hostname(),
Out: tcpconnch,
Wg: wg,
Expand All @@ -105,8 +105,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
go httpDo(ctx, &httpConfig{
Headers: creq.HTTPRequestHeaders,
Logger: logger,
MaxAcceptableBody: config.MaxAcceptableBody,
NewClient: config.NewHTTPClient,
MaxAcceptableBody: config.maxAcceptableBody,
NewClient: config.newHTTPClient,
Out: httpch,
URL: creq.HTTPRequest,
Wg: wg,
Expand All @@ -133,7 +133,7 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
Address: endpoint.Addr,
Endpoint: endpoint.Epnt,
Logger: logger,
NewQUICDialer: config.NewQUICDialer,
NewQUICDialer: config.newQUICDialer,
URLHostname: URL.Hostname(),
Out: quicconnch,
Wg: wg,
Expand All @@ -147,8 +147,8 @@ func measure(ctx context.Context, config *Handler, creq *ctrlRequest) (*ctrlResp
go httpDo(ctx, &httpConfig{
Headers: creq.HTTPRequestHeaders,
Logger: logger,
MaxAcceptableBody: config.MaxAcceptableBody,
NewClient: config.NewHTTP3Client,
MaxAcceptableBody: config.maxAcceptableBody,
NewClient: config.newHTTP3Client,
Out: http3ch,
URL: "https://" + cresp.HTTPRequest.DiscoveredH3Endpoint,
Wg: wg,
Expand Down

0 comments on commit 39ab3d2

Please sign in to comment.