diff --git a/internal/netxlite/http.go b/internal/netxlite/http.go index 1c57da76af..38c736d865 100644 --- a/internal/netxlite/http.go +++ b/internal/netxlite/http.go @@ -5,106 +5,12 @@ package netxlite // import ( - "context" - "errors" - "net" "net/http" - "time" oohttp "github.com/ooni/oohttp" "github.com/ooni/probe-cli/v3/internal/model" ) -// httpTransportErrWrapper is an HTTPTransport with error wrapping. -type httpTransportErrWrapper struct { - HTTPTransport model.HTTPTransport -} - -var _ model.HTTPTransport = &httpTransportErrWrapper{} - -func (txp *httpTransportErrWrapper) RoundTrip(req *http.Request) (*http.Response, error) { - resp, err := txp.HTTPTransport.RoundTrip(req) - if err != nil { - return nil, NewTopLevelGenericErrWrapper(err) - } - return resp, nil -} - -func (txp *httpTransportErrWrapper) CloseIdleConnections() { - txp.HTTPTransport.CloseIdleConnections() -} - -func (txp *httpTransportErrWrapper) Network() string { - return txp.HTTPTransport.Network() -} - -// httpTransportLogger is an HTTPTransport with logging. -type httpTransportLogger struct { - // HTTPTransport is the underlying HTTP transport. - HTTPTransport model.HTTPTransport - - // Logger is the underlying logger. - Logger model.DebugLogger -} - -var _ model.HTTPTransport = &httpTransportLogger{} - -func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) { - txp.Logger.Debugf("> %s %s", req.Method, req.URL.String()) - for key, values := range req.Header { - for _, value := range values { - txp.Logger.Debugf("> %s: %s", key, value) - } - } - txp.Logger.Debug(">") - resp, err := txp.HTTPTransport.RoundTrip(req) - if err != nil { - txp.Logger.Debugf("< %s", err) - return nil, err - } - txp.Logger.Debugf("< %d", resp.StatusCode) - for key, values := range resp.Header { - for _, value := range values { - txp.Logger.Debugf("< %s: %s", key, value) - } - } - txp.Logger.Debug("<") - return resp, nil -} - -func (txp *httpTransportLogger) CloseIdleConnections() { - txp.HTTPTransport.CloseIdleConnections() -} - -func (txp *httpTransportLogger) Network() string { - return txp.HTTPTransport.Network() -} - -// httpTransportConnectionsCloser is an HTTPTransport that -// correctly forwards CloseIdleConnections calls. -type httpTransportConnectionsCloser struct { - HTTPTransport model.HTTPTransport - Dialer model.Dialer - TLSDialer model.TLSDialer -} - -var _ model.HTTPTransport = &httpTransportConnectionsCloser{} - -func (txp *httpTransportConnectionsCloser) RoundTrip(req *http.Request) (*http.Response, error) { - return txp.HTTPTransport.RoundTrip(req) -} - -func (txp *httpTransportConnectionsCloser) Network() string { - return txp.HTTPTransport.Network() -} - -// CloseIdleConnections forwards the CloseIdleConnections calls. -func (txp *httpTransportConnectionsCloser) CloseIdleConnections() { - txp.HTTPTransport.CloseIdleConnections() - txp.Dialer.CloseIdleConnections() - txp.TLSDialer.CloseIdleConnections() -} - // NewHTTPTransportWithResolver creates a new HTTP transport using // the stdlib for everything but the given resolver. func NewHTTPTransportWithResolver(logger model.DebugLogger, reso model.Resolver) model.HTTPTransport { @@ -185,26 +91,6 @@ func newOOHTTPBaseTransport(dialer model.Dialer, tlsDialer model.TLSDialer) mode } } -// stdlibTransport wraps oohttp.StdlibTransport to add .Network() -type httpTransportStdlib struct { - StdlibTransport *oohttp.StdlibTransport -} - -var _ model.HTTPTransport = &httpTransportStdlib{} - -func (txp *httpTransportStdlib) CloseIdleConnections() { - txp.StdlibTransport.CloseIdleConnections() -} - -func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) { - return txp.StdlibTransport.RoundTrip(req) -} - -// Network implements HTTPTransport.Network. -func (txp *httpTransportStdlib) Network() string { - return "tcp" -} - // WrapHTTPTransport creates an HTTPTransport using the given logger // and guarantees that returned errors are wrapped. // @@ -216,105 +102,6 @@ func WrapHTTPTransport(logger model.DebugLogger, txp model.HTTPTransport) model. } } -// httpDialerWithReadTimeout enforces a read timeout for all HTTP -// connections. See https://github.com/ooni/probe/issues/1609. -type httpDialerWithReadTimeout struct { - Dialer model.Dialer -} - -var _ model.Dialer = &httpDialerWithReadTimeout{} - -func (d *httpDialerWithReadTimeout) CloseIdleConnections() { - d.Dialer.CloseIdleConnections() -} - -// DialContext implements Dialer.DialContext. -func (d *httpDialerWithReadTimeout) DialContext( - ctx context.Context, network, address string) (net.Conn, error) { - conn, err := d.Dialer.DialContext(ctx, network, address) - if err != nil { - return nil, err - } - return &httpConnWithReadTimeout{conn}, nil -} - -// httpTLSDialerWithReadTimeout enforces a read timeout for all HTTP -// connections. See https://github.com/ooni/probe/issues/1609. -type httpTLSDialerWithReadTimeout struct { - TLSDialer model.TLSDialer -} - -var _ model.TLSDialer = &httpTLSDialerWithReadTimeout{} - -func (d *httpTLSDialerWithReadTimeout) CloseIdleConnections() { - d.TLSDialer.CloseIdleConnections() -} - -// ErrNotTLSConn occur when an interface accepts a net.Conn but -// internally needs a TLSConn and you pass a net.Conn that doesn't -// implement TLSConn to such an interface. -var ErrNotTLSConn = errors.New("not a TLSConn") - -// DialTLSContext implements TLSDialer's DialTLSContext. -func (d *httpTLSDialerWithReadTimeout) DialTLSContext( - ctx context.Context, network, address string) (net.Conn, error) { - conn, err := d.TLSDialer.DialTLSContext(ctx, network, address) - if err != nil { - return nil, err - } - tconn, okay := conn.(TLSConn) // part of the contract but let's be graceful - if !okay { - conn.Close() // we own the conn here - return nil, ErrNotTLSConn - } - return &httpTLSConnWithReadTimeout{tconn}, nil -} - -// httpConnWithReadTimeout enforces a read timeout for all HTTP -// connections. See https://github.com/ooni/probe/issues/1609. -type httpConnWithReadTimeout struct { - net.Conn -} - -// httpConnReadTimeout is the read timeout we apply to all HTTP -// conns (see https://github.com/ooni/probe/issues/1609). -// -// This timeout is meant as a fallback mechanism so that a stuck -// connection will _eventually_ fail. This is why it is set to -// a large value (300 seconds when writing this note). -// -// There should be other mechanisms to ensure that the code is -// lively: the context during the RoundTrip and iox.ReadAllContext -// when reading the body. They should kick in earlier. But we -// additionally want to avoid leaking a (parked?) connection and -// the corresponding goroutine, hence this large timeout. -// -// A future @bassosimone may understand this problem even better -// and possibly apply an even better fix to this issue. This -// will happen when we'll be able to further study the anomalies -// described in https://github.com/ooni/probe/issues/1609. -const httpConnReadTimeout = 300 * time.Second - -// Read implements Conn.Read. -func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) { - c.Conn.SetReadDeadline(time.Now().Add(httpConnReadTimeout)) - defer c.Conn.SetReadDeadline(time.Time{}) - return c.Conn.Read(b) -} - -// httpTLSConnWithReadTimeout enforces a read timeout for all HTTP -// connections. See https://github.com/ooni/probe/issues/1609. -type httpTLSConnWithReadTimeout struct { - TLSConn -} - -// Read implements Conn.Read. -func (c *httpTLSConnWithReadTimeout) Read(b []byte) (int, error) { - c.TLSConn.SetReadDeadline(time.Now().Add(httpConnReadTimeout)) - defer c.TLSConn.SetReadDeadline(time.Time{}) - return c.TLSConn.Read(b) -} - // NewHTTPTransportStdlib creates a new HTTPTransport using // the stdlib for DNS resolutions and TLS. // @@ -357,19 +144,3 @@ func NewHTTPClient(txp model.HTTPTransport) model.HTTPClient { func WrapHTTPClient(clnt model.HTTPClient) model.HTTPClient { return &httpClientErrWrapper{clnt} } - -type httpClientErrWrapper struct { - HTTPClient model.HTTPClient -} - -func (c *httpClientErrWrapper) Do(req *http.Request) (*http.Response, error) { - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, NewTopLevelGenericErrWrapper(err) - } - return resp, nil -} - -func (c *httpClientErrWrapper) CloseIdleConnections() { - c.HTTPClient.CloseIdleConnections() -} diff --git a/internal/netxlite/http_test.go b/internal/netxlite/http_test.go index 2fc204850f..e5d29a8c8d 100644 --- a/internal/netxlite/http_test.go +++ b/internal/netxlite/http_test.go @@ -3,13 +3,10 @@ package netxlite import ( "context" "errors" - "io" "net" "net/http" - "strings" "sync/atomic" "testing" - "time" "github.com/apex/log" "github.com/ooni/probe-cli/v3/internal/mocks" @@ -37,191 +34,6 @@ func TestNewHTTPTransportWithResolver(t *testing.T) { } } -func TestHTTPTransportErrWrapper(t *testing.T) { - t.Run("RoundTrip", func(t *testing.T) { - t.Run("with failure", func(t *testing.T) { - txp := &httpTransportErrWrapper{ - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return nil, io.EOF - }, - }, - } - resp, err := txp.RoundTrip(&http.Request{}) - var errWrapper *ErrWrapper - if !errors.As(err, &errWrapper) { - t.Fatal("the returned error is not an ErrWrapper") - } - if errWrapper.Failure != FailureEOFError { - t.Fatal("unexpected failure", errWrapper.Failure) - } - if resp != nil { - t.Fatal("expected nil response") - } - }) - - t.Run("with success", func(t *testing.T) { - expect := &http.Response{} - txp := &httpTransportErrWrapper{ - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return expect, nil - }, - }, - } - resp, err := txp.RoundTrip(&http.Request{}) - if err != nil { - t.Fatal(err) - } - if resp != expect { - t.Fatal("not the expected response") - } - }) - }) -} - -func TestHTTPTransportLogger(t *testing.T) { - t.Run("RoundTrip", func(t *testing.T) { - t.Run("with failure", func(t *testing.T) { - var count int - lo := &mocks.Logger{ - MockDebug: func(message string) { - count++ - }, - MockDebugf: func(format string, v ...interface{}) { - count++ - }, - } - txp := &httpTransportLogger{ - Logger: lo, - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return nil, io.EOF - }, - }, - } - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com") - if !errors.Is(err, io.EOF) { - t.Fatal("not the error we expected") - } - if resp != nil { - t.Fatal("expected nil response here") - } - if count < 1 { - t.Fatal("no logs?!") - } - }) - - t.Run("with success", func(t *testing.T) { - var count int - lo := &mocks.Logger{ - MockDebug: func(message string) { - count++ - }, - MockDebugf: func(format string, v ...interface{}) { - count++ - }, - } - txp := &httpTransportLogger{ - Logger: lo, - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return &http.Response{ - Body: io.NopCloser(strings.NewReader("")), - Header: http.Header{ - "Server": []string{"antani/0.1.0"}, - }, - StatusCode: 200, - }, nil - }, - }, - } - client := &http.Client{Transport: txp} - req, err := http.NewRequest("GET", "https://www.google.com", nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("User-Agent", "miniooni/0.1.0-dev") - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - ReadAllContext(context.Background(), resp.Body) - resp.Body.Close() - if count < 1 { - t.Fatal("no logs?!") - } - }) - }) - - t.Run("CloseIdleConnections", func(t *testing.T) { - calls := &atomic.Int64{} - txp := &httpTransportLogger{ - HTTPTransport: &mocks.HTTPTransport{ - MockCloseIdleConnections: func() { - calls.Add(1) - }, - }, - Logger: log.Log, - } - txp.CloseIdleConnections() - if calls.Load() != 1 { - t.Fatal("not called") - } - }) -} - -func TestHTTPTransportConnectionsCloser(t *testing.T) { - t.Run("CloseIdleConnections", func(t *testing.T) { - var ( - calledTxp bool - calledDialer bool - calledTLS bool - ) - txp := &httpTransportConnectionsCloser{ - HTTPTransport: &mocks.HTTPTransport{ - MockCloseIdleConnections: func() { - calledTxp = true - }, - }, - Dialer: &mocks.Dialer{ - MockCloseIdleConnections: func() { - calledDialer = true - }, - }, - TLSDialer: &mocks.TLSDialer{ - MockCloseIdleConnections: func() { - calledTLS = true - }, - }, - } - txp.CloseIdleConnections() - if !calledDialer || !calledTLS || !calledTxp { - t.Fatal("not called") - } - }) - - t.Run("RoundTrip", func(t *testing.T) { - expected := errors.New("mocked error") - txp := &httpTransportConnectionsCloser{ - HTTPTransport: &mocks.HTTPTransport{ - MockRoundTrip: func(req *http.Request) (*http.Response, error) { - return nil, expected - }, - }, - } - client := &http.Client{Transport: txp} - resp, err := client.Get("https://www.google.com") - if !errors.Is(err, expected) { - t.Fatal("unexpected err", err) - } - if resp != nil { - t.Fatal("unexpected resp") - } - }) -} - func TestNewHTTPTransport(t *testing.T) { t.Run("works as intended with failing dialer", func(t *testing.T) { called := &atomic.Int64{} @@ -291,178 +103,6 @@ func TestNewHTTPTransport(t *testing.T) { }) } -func TestHTTPDialerWithReadTimeout(t *testing.T) { - t.Run("DialContext", func(t *testing.T) { - t.Run("on success", func(t *testing.T) { - var ( - calledWithZeroTime bool - calledWithNonZeroTime bool - ) - origConn := &mocks.Conn{ - MockSetReadDeadline: func(t time.Time) error { - switch t.IsZero() { - case true: - calledWithZeroTime = true - case false: - calledWithNonZeroTime = true - } - return nil - }, - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - } - d := &httpDialerWithReadTimeout{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return origConn, nil - }, - }, - } - ctx := context.Background() - conn, err := d.DialContext(ctx, "", "") - if err != nil { - t.Fatal(err) - } - if _, okay := conn.(*httpConnWithReadTimeout); !okay { - t.Fatal("invalid conn type") - } - if conn.(*httpConnWithReadTimeout).Conn != origConn { - t.Fatal("invalid origin conn") - } - b := make([]byte, 1024) - count, err := conn.Read(b) - if !errors.Is(err, io.EOF) { - t.Fatal("invalid error") - } - if count != 0 { - t.Fatal("invalid count") - } - if !calledWithZeroTime || !calledWithNonZeroTime { - t.Fatal("not called") - } - }) - - t.Run("on failure", func(t *testing.T) { - expected := errors.New("mocked error") - d := &httpDialerWithReadTimeout{ - Dialer: &mocks.Dialer{ - MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, expected - }, - }, - } - conn, err := d.DialContext(context.Background(), "", "") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - }) - }) -} - -func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { - t.Run("DialContext", func(t *testing.T) { - t.Run("on success", func(t *testing.T) { - var ( - calledWithZeroTime bool - calledWithNonZeroTime bool - ) - origConn := &mocks.TLSConn{ - Conn: mocks.Conn{ - MockSetReadDeadline: func(t time.Time) error { - switch t.IsZero() { - case true: - calledWithZeroTime = true - case false: - calledWithNonZeroTime = true - } - return nil - }, - MockRead: func(b []byte) (int, error) { - return 0, io.EOF - }, - }, - } - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return origConn, nil - }, - }, - } - ctx := context.Background() - conn, err := d.DialTLSContext(ctx, "", "") - if err != nil { - t.Fatal(err) - } - if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { - t.Fatal("invalid conn type") - } - if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { - t.Fatal("invalid origin conn") - } - b := make([]byte, 1024) - count, err := conn.Read(b) - if !errors.Is(err, io.EOF) { - t.Fatal("invalid error") - } - if count != 0 { - t.Fatal("invalid count") - } - if !calledWithZeroTime || !calledWithNonZeroTime { - t.Fatal("not called") - } - }) - - t.Run("on failure", func(t *testing.T) { - expected := errors.New("mocked error") - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return nil, expected - }, - }, - } - conn, err := d.DialTLSContext(context.Background(), "", "") - if !errors.Is(err, expected) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - }) - - t.Run("with invalid conn type", func(t *testing.T) { - var called bool - d := &httpTLSDialerWithReadTimeout{ - TLSDialer: &mocks.TLSDialer{ - MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { - return &mocks.Conn{ - MockClose: func() error { - called = true - return nil - }, - }, nil - }, - }, - } - conn, err := d.DialTLSContext(context.Background(), "", "") - if !errors.Is(err, ErrNotTLSConn) { - t.Fatal("not the error we expected") - } - if conn != nil { - t.Fatal("expected nil conn here") - } - if !called { - t.Fatal("not called") - } - }) - }) -} - func TestNewHTTPTransportStdlib(t *testing.T) { txp := NewHTTPTransportStdlib(log.Log) ctx, cancel := context.WithCancel(context.Background()) @@ -484,63 +124,6 @@ func TestNewHTTPTransportStdlib(t *testing.T) { txp.CloseIdleConnections() } -func TestHTTPClientErrWrapper(t *testing.T) { - t.Run("Do", func(t *testing.T) { - t.Run("with failure", func(t *testing.T) { - clnt := &httpClientErrWrapper{ - HTTPClient: &mocks.HTTPClient{ - MockDo: func(req *http.Request) (*http.Response, error) { - return nil, io.EOF - }, - }, - } - resp, err := clnt.Do(&http.Request{}) - var errWrapper *ErrWrapper - if !errors.As(err, &errWrapper) { - t.Fatal("the returned error is not an ErrWrapper") - } - if errWrapper.Failure != FailureEOFError { - t.Fatal("unexpected failure", errWrapper.Failure) - } - if resp != nil { - t.Fatal("expected nil response") - } - }) - - t.Run("with success", func(t *testing.T) { - expect := &http.Response{} - clnt := &httpClientErrWrapper{ - HTTPClient: &mocks.HTTPClient{ - MockDo: func(req *http.Request) (*http.Response, error) { - return expect, nil - }, - }, - } - resp, err := clnt.Do(&http.Request{}) - if err != nil { - t.Fatal(err) - } - if resp != expect { - t.Fatal("not the expected response") - } - }) - }) - - t.Run("CloseIdleConnections", func(t *testing.T) { - var called bool - child := &mocks.HTTPClient{ - MockCloseIdleConnections: func() { - called = true - }, - } - clnt := &httpClientErrWrapper{child} - clnt.CloseIdleConnections() - if !called { - t.Fatal("not called") - } - }) -} - func TestNewHTTPClientStdlib(t *testing.T) { clnt := NewHTTPClientStdlib(model.DiscardLogger) ewc, ok := clnt.(*httpClientErrWrapper) diff --git a/internal/netxlite/httpcloser.go b/internal/netxlite/httpcloser.go new file mode 100644 index 0000000000..0d51bd25ba --- /dev/null +++ b/internal/netxlite/httpcloser.go @@ -0,0 +1,36 @@ +package netxlite + +// +// Code to ensure we forward CloseIdleConnection calls +// + +import ( + "net/http" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// httpTransportConnectionsCloser is an HTTPTransport that +// correctly forwards CloseIdleConnections calls. +type httpTransportConnectionsCloser struct { + HTTPTransport model.HTTPTransport + Dialer model.Dialer + TLSDialer model.TLSDialer +} + +var _ model.HTTPTransport = &httpTransportConnectionsCloser{} + +func (txp *httpTransportConnectionsCloser) RoundTrip(req *http.Request) (*http.Response, error) { + return txp.HTTPTransport.RoundTrip(req) +} + +func (txp *httpTransportConnectionsCloser) Network() string { + return txp.HTTPTransport.Network() +} + +// CloseIdleConnections forwards the CloseIdleConnections calls. +func (txp *httpTransportConnectionsCloser) CloseIdleConnections() { + txp.HTTPTransport.CloseIdleConnections() + txp.Dialer.CloseIdleConnections() + txp.TLSDialer.CloseIdleConnections() +} diff --git a/internal/netxlite/httpcloser_test.go b/internal/netxlite/httpcloser_test.go new file mode 100644 index 0000000000..71a28ab393 --- /dev/null +++ b/internal/netxlite/httpcloser_test.go @@ -0,0 +1,59 @@ +package netxlite + +import ( + "errors" + "net/http" + "testing" + + "github.com/ooni/probe-cli/v3/internal/mocks" +) + +func TestHTTPTransportConnectionsCloser(t *testing.T) { + t.Run("CloseIdleConnections", func(t *testing.T) { + var ( + calledTxp bool + calledDialer bool + calledTLS bool + ) + txp := &httpTransportConnectionsCloser{ + HTTPTransport: &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + calledTxp = true + }, + }, + Dialer: &mocks.Dialer{ + MockCloseIdleConnections: func() { + calledDialer = true + }, + }, + TLSDialer: &mocks.TLSDialer{ + MockCloseIdleConnections: func() { + calledTLS = true + }, + }, + } + txp.CloseIdleConnections() + if !calledDialer || !calledTLS || !calledTxp { + t.Fatal("not called") + } + }) + + t.Run("RoundTrip", func(t *testing.T) { + expected := errors.New("mocked error") + txp := &httpTransportConnectionsCloser{ + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, expected + }, + }, + } + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com") + if !errors.Is(err, expected) { + t.Fatal("unexpected err", err) + } + if resp != nil { + t.Fatal("unexpected resp") + } + }) +} diff --git a/internal/netxlite/httperrwrap.go b/internal/netxlite/httperrwrap.go new file mode 100644 index 0000000000..989153c867 --- /dev/null +++ b/internal/netxlite/httperrwrap.go @@ -0,0 +1,50 @@ +package netxlite + +// +// Code to ensure we wrap errors +// + +import ( + "net/http" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// httpTransportErrWrapper is an HTTPTransport with error wrapping. +type httpTransportErrWrapper struct { + HTTPTransport model.HTTPTransport +} + +var _ model.HTTPTransport = &httpTransportErrWrapper{} + +func (txp *httpTransportErrWrapper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := txp.HTTPTransport.RoundTrip(req) + if err != nil { + return nil, NewTopLevelGenericErrWrapper(err) + } + return resp, nil +} + +func (txp *httpTransportErrWrapper) CloseIdleConnections() { + txp.HTTPTransport.CloseIdleConnections() +} + +func (txp *httpTransportErrWrapper) Network() string { + return txp.HTTPTransport.Network() +} + +type httpClientErrWrapper struct { + HTTPClient model.HTTPClient +} + +func (c *httpClientErrWrapper) Do(req *http.Request) (*http.Response, error) { + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, NewTopLevelGenericErrWrapper(err) + } + return resp, nil +} + +func (c *httpClientErrWrapper) CloseIdleConnections() { + c.HTTPClient.CloseIdleConnections() +} diff --git a/internal/netxlite/httperrwrap_test.go b/internal/netxlite/httperrwrap_test.go new file mode 100644 index 0000000000..55e96f1675 --- /dev/null +++ b/internal/netxlite/httperrwrap_test.go @@ -0,0 +1,110 @@ +package netxlite + +import ( + "errors" + "io" + "net/http" + "testing" + + "github.com/ooni/probe-cli/v3/internal/mocks" +) + +func TestHTTPTransportErrWrapper(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("with failure", func(t *testing.T) { + txp := &httpTransportErrWrapper{ + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + } + resp, err := txp.RoundTrip(&http.Request{}) + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("the returned error is not an ErrWrapper") + } + if errWrapper.Failure != FailureEOFError { + t.Fatal("unexpected failure", errWrapper.Failure) + } + if resp != nil { + t.Fatal("expected nil response") + } + }) + + t.Run("with success", func(t *testing.T) { + expect := &http.Response{} + txp := &httpTransportErrWrapper{ + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return expect, nil + }, + }, + } + resp, err := txp.RoundTrip(&http.Request{}) + if err != nil { + t.Fatal(err) + } + if resp != expect { + t.Fatal("not the expected response") + } + }) + }) +} + +func TestHTTPClientErrWrapper(t *testing.T) { + t.Run("Do", func(t *testing.T) { + t.Run("with failure", func(t *testing.T) { + clnt := &httpClientErrWrapper{ + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + } + resp, err := clnt.Do(&http.Request{}) + var errWrapper *ErrWrapper + if !errors.As(err, &errWrapper) { + t.Fatal("the returned error is not an ErrWrapper") + } + if errWrapper.Failure != FailureEOFError { + t.Fatal("unexpected failure", errWrapper.Failure) + } + if resp != nil { + t.Fatal("expected nil response") + } + }) + + t.Run("with success", func(t *testing.T) { + expect := &http.Response{} + clnt := &httpClientErrWrapper{ + HTTPClient: &mocks.HTTPClient{ + MockDo: func(req *http.Request) (*http.Response, error) { + return expect, nil + }, + }, + } + resp, err := clnt.Do(&http.Request{}) + if err != nil { + t.Fatal(err) + } + if resp != expect { + t.Fatal("not the expected response") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + var called bool + child := &mocks.HTTPClient{ + MockCloseIdleConnections: func() { + called = true + }, + } + clnt := &httpClientErrWrapper{child} + clnt.CloseIdleConnections() + if !called { + t.Fatal("not called") + } + }) +} diff --git a/internal/netxlite/httplogger.go b/internal/netxlite/httplogger.go new file mode 100644 index 0000000000..048f52c900 --- /dev/null +++ b/internal/netxlite/httplogger.go @@ -0,0 +1,53 @@ +package netxlite + +// +// Code to ensure we log round trips +// + +import ( + "net/http" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// httpTransportLogger is an HTTPTransport with logging. +type httpTransportLogger struct { + // HTTPTransport is the underlying HTTP transport. + HTTPTransport model.HTTPTransport + + // Logger is the underlying logger. + Logger model.DebugLogger +} + +var _ model.HTTPTransport = &httpTransportLogger{} + +func (txp *httpTransportLogger) RoundTrip(req *http.Request) (*http.Response, error) { + txp.Logger.Debugf("> %s %s", req.Method, req.URL.String()) + for key, values := range req.Header { + for _, value := range values { + txp.Logger.Debugf("> %s: %s", key, value) + } + } + txp.Logger.Debug(">") + resp, err := txp.HTTPTransport.RoundTrip(req) + if err != nil { + txp.Logger.Debugf("< %s", err) + return nil, err + } + txp.Logger.Debugf("< %d", resp.StatusCode) + for key, values := range resp.Header { + for _, value := range values { + txp.Logger.Debugf("< %s: %s", key, value) + } + } + txp.Logger.Debug("<") + return resp, nil +} + +func (txp *httpTransportLogger) CloseIdleConnections() { + txp.HTTPTransport.CloseIdleConnections() +} + +func (txp *httpTransportLogger) Network() string { + return txp.HTTPTransport.Network() +} diff --git a/internal/netxlite/httplogger_test.go b/internal/netxlite/httplogger_test.go new file mode 100644 index 0000000000..d2e02c22f5 --- /dev/null +++ b/internal/netxlite/httplogger_test.go @@ -0,0 +1,106 @@ +package netxlite + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + + "github.com/apex/log" + "github.com/ooni/probe-cli/v3/internal/mocks" +) + +func TestHTTPTransportLogger(t *testing.T) { + t.Run("RoundTrip", func(t *testing.T) { + t.Run("with failure", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebug: func(message string) { + count++ + }, + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } + txp := &httpTransportLogger{ + Logger: lo, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return nil, io.EOF + }, + }, + } + client := &http.Client{Transport: txp} + resp, err := client.Get("https://www.google.com") + if !errors.Is(err, io.EOF) { + t.Fatal("not the error we expected") + } + if resp != nil { + t.Fatal("expected nil response here") + } + if count < 1 { + t.Fatal("no logs?!") + } + }) + + t.Run("with success", func(t *testing.T) { + var count int + lo := &mocks.Logger{ + MockDebug: func(message string) { + count++ + }, + MockDebugf: func(format string, v ...interface{}) { + count++ + }, + } + txp := &httpTransportLogger{ + Logger: lo, + HTTPTransport: &mocks.HTTPTransport{ + MockRoundTrip: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Body: io.NopCloser(strings.NewReader("")), + Header: http.Header{ + "Server": []string{"antani/0.1.0"}, + }, + StatusCode: 200, + }, nil + }, + }, + } + client := &http.Client{Transport: txp} + req, err := http.NewRequest("GET", "https://www.google.com", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "miniooni/0.1.0-dev") + resp, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + ReadAllContext(context.Background(), resp.Body) + resp.Body.Close() + if count < 1 { + t.Fatal("no logs?!") + } + }) + }) + + t.Run("CloseIdleConnections", func(t *testing.T) { + calls := &atomic.Int64{} + txp := &httpTransportLogger{ + HTTPTransport: &mocks.HTTPTransport{ + MockCloseIdleConnections: func() { + calls.Add(1) + }, + }, + Logger: log.Log, + } + txp.CloseIdleConnections() + if calls.Load() != 1 { + t.Fatal("not called") + } + }) +} diff --git a/internal/netxlite/httpstdlib.go b/internal/netxlite/httpstdlib.go new file mode 100644 index 0000000000..2f3063dbf9 --- /dev/null +++ b/internal/netxlite/httpstdlib.go @@ -0,0 +1,32 @@ +package netxlite + +// +// Code to adapt oohttp to the stdlib and the stdlib to our HTTP models +// + +import ( + "net/http" + + oohttp "github.com/ooni/oohttp" + "github.com/ooni/probe-cli/v3/internal/model" +) + +// stdlibTransport wraps oohttp.StdlibTransport to add .Network() +type httpTransportStdlib struct { + StdlibTransport *oohttp.StdlibTransport +} + +var _ model.HTTPTransport = &httpTransportStdlib{} + +func (txp *httpTransportStdlib) CloseIdleConnections() { + txp.StdlibTransport.CloseIdleConnections() +} + +func (txp *httpTransportStdlib) RoundTrip(req *http.Request) (*http.Response, error) { + return txp.StdlibTransport.RoundTrip(req) +} + +// Network implements HTTPTransport.Network. +func (txp *httpTransportStdlib) Network() string { + return "tcp" +} diff --git a/internal/netxlite/httptimeout.go b/internal/netxlite/httptimeout.go new file mode 100644 index 0000000000..9ae9f01dfb --- /dev/null +++ b/internal/netxlite/httptimeout.go @@ -0,0 +1,114 @@ +package netxlite + +// +// Code to ensure we have proper read timeouts (for reliability +// as described by https://github.com/ooni/probe/issues/1609) +// + +import ( + "context" + "errors" + "net" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// httpDialerWithReadTimeout enforces a read timeout for all HTTP +// connections. See https://github.com/ooni/probe/issues/1609. +type httpDialerWithReadTimeout struct { + Dialer model.Dialer +} + +var _ model.Dialer = &httpDialerWithReadTimeout{} + +func (d *httpDialerWithReadTimeout) CloseIdleConnections() { + d.Dialer.CloseIdleConnections() +} + +// DialContext implements Dialer.DialContext. +func (d *httpDialerWithReadTimeout) DialContext( + ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + return &httpConnWithReadTimeout{conn}, nil +} + +// httpTLSDialerWithReadTimeout enforces a read timeout for all HTTP +// connections. See https://github.com/ooni/probe/issues/1609. +type httpTLSDialerWithReadTimeout struct { + TLSDialer model.TLSDialer +} + +var _ model.TLSDialer = &httpTLSDialerWithReadTimeout{} + +func (d *httpTLSDialerWithReadTimeout) CloseIdleConnections() { + d.TLSDialer.CloseIdleConnections() +} + +// ErrNotTLSConn occur when an interface accepts a net.Conn but +// internally needs a TLSConn and you pass a net.Conn that doesn't +// implement TLSConn to such an interface. +var ErrNotTLSConn = errors.New("not a TLSConn") + +// DialTLSContext implements TLSDialer's DialTLSContext. +func (d *httpTLSDialerWithReadTimeout) DialTLSContext( + ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.TLSDialer.DialTLSContext(ctx, network, address) + if err != nil { + return nil, err + } + tconn, okay := conn.(TLSConn) // part of the contract but let's be graceful + if !okay { + conn.Close() // we own the conn here + return nil, ErrNotTLSConn + } + return &httpTLSConnWithReadTimeout{tconn}, nil +} + +// httpConnWithReadTimeout enforces a read timeout for all HTTP +// connections. See https://github.com/ooni/probe/issues/1609. +type httpConnWithReadTimeout struct { + net.Conn +} + +// httpConnReadTimeout is the read timeout we apply to all HTTP +// conns (see https://github.com/ooni/probe/issues/1609). +// +// This timeout is meant as a fallback mechanism so that a stuck +// connection will _eventually_ fail. This is why it is set to +// a large value (300 seconds when writing this note). +// +// There should be other mechanisms to ensure that the code is +// lively: the context during the RoundTrip and iox.ReadAllContext +// when reading the body. They should kick in earlier. But we +// additionally want to avoid leaking a (parked?) connection and +// the corresponding goroutine, hence this large timeout. +// +// A future @bassosimone may understand this problem even better +// and possibly apply an even better fix to this issue. This +// will happen when we'll be able to further study the anomalies +// described in https://github.com/ooni/probe/issues/1609. +const httpConnReadTimeout = 300 * time.Second + +// Read implements Conn.Read. +func (c *httpConnWithReadTimeout) Read(b []byte) (int, error) { + c.Conn.SetReadDeadline(time.Now().Add(httpConnReadTimeout)) + defer c.Conn.SetReadDeadline(time.Time{}) + return c.Conn.Read(b) +} + +// httpTLSConnWithReadTimeout enforces a read timeout for all HTTP +// connections. See https://github.com/ooni/probe/issues/1609. +type httpTLSConnWithReadTimeout struct { + TLSConn +} + +// Read implements Conn.Read. +func (c *httpTLSConnWithReadTimeout) Read(b []byte) (int, error) { + c.TLSConn.SetReadDeadline(time.Now().Add(httpConnReadTimeout)) + defer c.TLSConn.SetReadDeadline(time.Time{}) + return c.TLSConn.Read(b) +} diff --git a/internal/netxlite/httptimeout_test.go b/internal/netxlite/httptimeout_test.go new file mode 100644 index 0000000000..44a79bc290 --- /dev/null +++ b/internal/netxlite/httptimeout_test.go @@ -0,0 +1,184 @@ +package netxlite + +import ( + "context" + "errors" + "io" + "net" + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/mocks" +) + +func TestHTTPDialerWithReadTimeout(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.Conn{ + MockSetReadDeadline: func(t time.Time) error { + switch t.IsZero() { + case true: + calledWithZeroTime = true + case false: + calledWithNonZeroTime = true + } + return nil + }, + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + } + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, + }, + } + ctx := context.Background() + conn, err := d.DialContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpConnWithReadTimeout).Conn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &httpDialerWithReadTimeout{ + Dialer: &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, + }, + } + conn, err := d.DialContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + }) +} + +func TestHTTPTLSDialerWithReadTimeout(t *testing.T) { + t.Run("DialContext", func(t *testing.T) { + t.Run("on success", func(t *testing.T) { + var ( + calledWithZeroTime bool + calledWithNonZeroTime bool + ) + origConn := &mocks.TLSConn{ + Conn: mocks.Conn{ + MockSetReadDeadline: func(t time.Time) error { + switch t.IsZero() { + case true: + calledWithZeroTime = true + case false: + calledWithNonZeroTime = true + } + return nil + }, + MockRead: func(b []byte) (int, error) { + return 0, io.EOF + }, + }, + } + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return origConn, nil + }, + }, + } + ctx := context.Background() + conn, err := d.DialTLSContext(ctx, "", "") + if err != nil { + t.Fatal(err) + } + if _, okay := conn.(*httpTLSConnWithReadTimeout); !okay { + t.Fatal("invalid conn type") + } + if conn.(*httpTLSConnWithReadTimeout).TLSConn != origConn { + t.Fatal("invalid origin conn") + } + b := make([]byte, 1024) + count, err := conn.Read(b) + if !errors.Is(err, io.EOF) { + t.Fatal("invalid error") + } + if count != 0 { + t.Fatal("invalid count") + } + if !calledWithZeroTime || !calledWithNonZeroTime { + t.Fatal("not called") + } + }) + + t.Run("on failure", func(t *testing.T) { + expected := errors.New("mocked error") + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return nil, expected + }, + }, + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, expected) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + }) + + t.Run("with invalid conn type", func(t *testing.T) { + var called bool + d := &httpTLSDialerWithReadTimeout{ + TLSDialer: &mocks.TLSDialer{ + MockDialTLSContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return &mocks.Conn{ + MockClose: func() error { + called = true + return nil + }, + }, nil + }, + }, + } + conn, err := d.DialTLSContext(context.Background(), "", "") + if !errors.Is(err, ErrNotTLSConn) { + t.Fatal("not the error we expected") + } + if conn != nil { + t.Fatal("expected nil conn here") + } + if !called { + t.Fatal("not called") + } + }) + }) +}