diff --git a/rpc/client/connect.go b/rpc/client/connect.go index cdbbbea8..5ddca4de 100644 --- a/rpc/client/connect.go +++ b/rpc/client/connect.go @@ -31,7 +31,7 @@ func (c *Client) createGRPCClient(ctx context.Context) (err error) { var errInvalidEndpoint = errors.New("invalid endpoint options") -func (c *Client) openGRPCConn(ctx context.Context) error { +func (c *Client) openGRPCConn(ctx context.Context, extraDialOpts ...grpcstd.DialOption) error { if c.conn != nil { return nil } @@ -51,10 +51,16 @@ func (c *Client) openGRPCConn(ctx context.Context) error { dialCtx, cancel := context.WithTimeout(ctx, c.dialTimeout) var err error - c.conn, err = grpcstd.DialContext(dialCtx, c.addr, + dialOpts := make([]grpcstd.DialOption, 0, 2+len(extraDialOpts)) + dialOpts = append(dialOpts, grpcstd.WithTransportCredentials(creds), grpcstd.WithBlock(), ) + if extraDialOpts != nil { + dialOpts = append(dialOpts, extraDialOpts...) + } + + c.conn, err = grpcstd.DialContext(dialCtx, c.addr, dialOpts...) cancel() diff --git a/rpc/client/connect_test.go b/rpc/client/connect_test.go new file mode 100644 index 00000000..f76f7f57 --- /dev/null +++ b/rpc/client/connect_test.go @@ -0,0 +1,35 @@ +package client + +import ( + "context" + "crypto/tls" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" +) + +func TestClient_Init(t *testing.T) { + t.Run("TLS handshake failure", func(t *testing.T) { + lis := bufconn.Listen(1024) // size does not matter in this test + + srv := grpc.NewServer() + t.Cleanup(srv.Stop) + go func() { _ = srv.Serve(lis) }() + + c := New(WithNetworkURIAddress("grpcs://any:54321", new(tls.Config))...) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + t.Cleanup(cancel) + + err := c.openGRPCConn(ctx, grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) { + return lis.DialContext(ctx) + })) + // error is not wrapped properly, so we can do nothing more to check it. + // Text from stdlib tls.Conn.HandshakeContext. + require.ErrorContains(t, err, "first record does not look like a TLS handshake") + }) +}