diff --git a/internal/rpc/rpc.go b/internal/rpc/rpc.go index fdba30505b..8bdf9c2111 100644 --- a/internal/rpc/rpc.go +++ b/internal/rpc/rpc.go @@ -20,56 +20,22 @@ import ( "github.com/TBD54566975/ftl/internal/log" ) -// InitialiseClients initialises global HTTP clients used by the RPC system. +var ( + authenticators map[string]string + allowInsecure bool +) + +// InitialiseClients initialises parameters for the HTTP clients used by the RPC system. +// +// # To avoid caching issues these clients are not global, but a created for each endpoint // // "authenticators" are authenticator executables to use for each endpoint. The key is the URL of the endpoint, the // value is the path to the authenticator executable. // // "allowInsecure" skips certificate verification, making TLS susceptible to machine-in-the-middle attacks. -func InitialiseClients(authenticators map[string]string, allowInsecure bool) { - // We can't have a client-wide timeout because it also applies to - // streaming RPCs, timing them out. - h2cClient = &http.Client{ - Transport: authn.Transport(&http2.Transport{ - AllowHTTP: true, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: allowInsecure, // #nosec G402 - }, - DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { - conn, err := dialer.Dial(network, addr) - return conn, err - }, - }, authenticators), - } - tlsClient = &http.Client{ - Transport: authn.Transport(&http2.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: allowInsecure, // #nosec G402 - }, - DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) { - tlsDialer := tls.Dialer{Config: config, NetDialer: dialer} - conn, err := tlsDialer.DialContext(ctx, network, addr) - return conn, err - }, - }, authenticators), - } - - // Use a separate client for HTTP/1.1 with TLS. - http1TLSClient = &http.Client{ - Transport: authn.Transport(&http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: allowInsecure, // #nosec G402 - }, - DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - logger := log.FromContext(ctx) - logger.Debugf("HTTP/1.1 connecting to %s %s", network, addr) - - tlsDialer := tls.Dialer{NetDialer: dialer} - conn, err := tlsDialer.DialContext(ctx, network, addr) - return conn, fmt.Errorf("HTTP/1.1 TLS dial failed: %w", err) - }, - }, authenticators), - } +func InitialiseClients(authenticatorsParam map[string]string, allowInsecureParam bool) { + authenticators = authenticatorsParam + allowInsecure = allowInsecureParam } func init() { @@ -80,10 +46,6 @@ var ( dialer = &net.Dialer{ Timeout: time.Second * 10, } - h2cClient *http.Client - tlsClient *http.Client - // Temporary client for HTTP/1.1 with TLS to help with debugging. - http1TLSClient *http.Client ) type Pingable interface { @@ -92,19 +54,55 @@ type Pingable interface { // GetHTTPClient returns a HTTP client usable for the given URL. func GetHTTPClient(url string) *http.Client { - if h2cClient == nil { - panic("rpc.InitialiseClients() must be called before GetHTTPClient()") - } // TEMP_GRPC_HTTP1_ONLY set to non blank will use http1TLSClient if os.Getenv("TEMP_GRPC_HTTP1_ONLY") != "" { - return http1TLSClient + return &http.Client{ + Transport: authn.Transport(&http.Transport{ + IdleConnTimeout: time.Minute, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: allowInsecure, // #nosec G402 + }, + DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + logger := log.FromContext(ctx) + logger.Debugf("HTTP/1.1 connecting to %s %s", network, addr) + + tlsDialer := tls.Dialer{NetDialer: dialer} + conn, err := tlsDialer.DialContext(ctx, network, addr) + return conn, fmt.Errorf("HTTP/1.1 TLS dial failed: %w", err) + }, + }, authenticators), + } } if strings.HasPrefix(url, "http://") { - return h2cClient + return &http.Client{ + Transport: authn.Transport(&http2.Transport{ + IdleConnTimeout: time.Minute, + AllowHTTP: true, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: allowInsecure, // #nosec G402 + }, + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + conn, err := dialer.Dial(network, addr) + return conn, err + }, + }, authenticators), + } + } + return &http.Client{ + Transport: authn.Transport(&http2.Transport{ + IdleConnTimeout: time.Minute, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: allowInsecure, // #nosec G402 + }, + DialTLSContext: func(ctx context.Context, network, addr string, config *tls.Config) (net.Conn, error) { + tlsDialer := tls.Dialer{Config: config, NetDialer: dialer} + conn, err := tlsDialer.DialContext(ctx, network, addr) + return conn, err + }, + }, authenticators), } - return tlsClient } // ClientFactory is a function that creates a new client and is typically one of