From 6b872e8806a513b107659897dc20fdc2b92c3dfe Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Tue, 15 Oct 2024 18:04:50 -0700 Subject: [PATCH] Do not fail creating the provisioner HTTP client This commit avoids an error starting the CA if the `http.DefaultTransport` is not an `*http.Transport`. If the DefaultTransport is overwritten, the newHTTPClient method will return a simple *http.Client. With an *http.Transport, it will return a client that trusts the system certificate pool and the CA roots. --- authority/http_client.go | 40 +++++++++++++++++------------------ authority/http_client_test.go | 15 +++++++++++++ 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/authority/http_client.go b/authority/http_client.go index 377db8ae5..ff61e45fe 100644 --- a/authority/http_client.go +++ b/authority/http_client.go @@ -7,28 +7,28 @@ import ( "net/http" ) -// newHTTPClient returns an HTTP client that trusts the system cert pool and the -// given roots. +// newHTTPClient will return an HTTP client that trusts the system cert pool and +// the given roots, but only if the http.DefaultTransport is an *http.Transport. +// If not, it will return the default HTTP client. func newHTTPClient(roots ...*x509.Certificate) (*http.Client, error) { - pool, err := x509.SystemCertPool() - if err != nil { - return nil, fmt.Errorf("error initializing http client: %w", err) - } - for _, crt := range roots { - pool.AddCert(crt) - } + if tr, ok := http.DefaultTransport.(*http.Transport); ok { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("error initializing http client: %w", err) + } + for _, crt := range roots { + pool.AddCert(crt) + } - tr, ok := http.DefaultTransport.(*http.Transport) - if !ok { - return nil, fmt.Errorf("error initializing http client: type is not *http.Transport") - } - tr = tr.Clone() - tr.TLSClientConfig = &tls.Config{ - MinVersion: tls.VersionTLS12, - RootCAs: pool, + tr = tr.Clone() + tr.TLSClientConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: pool, + } + return &http.Client{ + Transport: tr, + }, nil } - return &http.Client{ - Transport: tr, - }, nil + return &http.Client{}, nil } diff --git a/authority/http_client_test.go b/authority/http_client_test.go index b7698e94c..979c884df 100644 --- a/authority/http_client_test.go +++ b/authority/http_client_test.go @@ -102,4 +102,19 @@ func Test_newHTTPClient(t *testing.T) { assert.Error(t, err) }) }) + + t.Run("custom transport", func(t *testing.T) { + tmp := http.DefaultTransport + t.Cleanup(func() { + http.DefaultTransport = tmp + }) + transport := struct { + http.RoundTripper + }{http.DefaultTransport} + http.DefaultTransport = transport + + client, err := newHTTPClient(auth.rootX509Certs...) + assert.NoError(t, err) + assert.Equal(t, &http.Client{}, client) + }) }