diff --git a/client.go b/client.go index 53822adc..e9bc68ef 100644 --- a/client.go +++ b/client.go @@ -10,12 +10,11 @@ import ( "encoding/binary" "encoding/json" "fmt" + "net" "sync" "time" - gzk "github.com/go-zookeeper/zk" log "github.com/sirupsen/logrus" - "github.com/tsuna/gohbase/compression" "github.com/tsuna/gohbase/hrpc" "github.com/tsuna/gohbase/pb" @@ -99,15 +98,15 @@ type client struct { closeOnce sync.Once newRegionClientFn func(string, region.ClientType, int, time.Duration, - string, time.Duration, compression.Codec, region.Dialer) hrpc.RegionClient + string, time.Duration, compression.Codec, func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient compressionCodec compression.Codec - // zkDialer is passed through to Zk Connect() to configure custom connection settings - zkDialer gzk.Dialer + // zkDialer is used in the zkClient to connect to the quorum + zkDialer func(ctx context.Context, network, addr string) (net.Conn, error) // regionDialer is passed into the region client to connect to hbase in a custom way, // such as SOCKS proxy. - regionDialer region.Dialer + regionDialer func(ctx context.Context, network, addr string) (net.Conn, error) } // NewClient creates a new HBase client. @@ -279,7 +278,7 @@ func CompressionCodec(codec string) Option { // ZooKeeperDialer will return an option to pass the given dialer function // into the ZooKeeper client Connect() call, which allows for customizing // network connections. -func ZooKeeperDialer(dialer gzk.Dialer) Option { +func ZooKeeperDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Option { return func(c *client) { c.zkDialer = dialer } @@ -287,7 +286,7 @@ func ZooKeeperDialer(dialer gzk.Dialer) Option { // RegionDialer will return an option that uses the specified Dialer for // connecting to region servers. This allows for connecting through proxies. -func RegionDialer(dialer region.Dialer) Option { +func RegionDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Option { return func(c *client) { c.regionDialer = dialer } diff --git a/mockrc_test.go b/mockrc_test.go index 5d819c09..c35ecf0c 100644 --- a/mockrc_test.go +++ b/mockrc_test.go @@ -9,6 +9,7 @@ import ( "bytes" "context" "fmt" + "net" "sync" "sync/atomic" "time" @@ -177,7 +178,7 @@ func init() { func newMockRegionClient(addr string, ctype region.ClientType, queueSize int, flushInterval time.Duration, effectiveUser string, - readTimeout time.Duration, codec compression.Codec, dialer region.Dialer) hrpc.RegionClient { + readTimeout time.Duration, codec compression.Codec, dialer func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient { m.Lock() clients[addr]++ m.Unlock() diff --git a/region/client.go b/region/client.go index ec5580fc..acd15703 100644 --- a/region/client.go +++ b/region/client.go @@ -194,7 +194,7 @@ type client struct { compressor *compressor // dialer is used to connect to region servers in non-standard ways - dialer Dialer + dialer func(ctx context.Context, network, addr string) (net.Conn, error) } // QueueRPC will add an rpc call to the queue for processing by the writer goroutine diff --git a/region/new.go b/region/new.go index ea34b0a4..4ed37f2f 100644 --- a/region/new.go +++ b/region/new.go @@ -18,16 +18,10 @@ import ( "github.com/tsuna/gohbase/hrpc" ) -// Dialer is used to connect to region servers. net.Dialer conforms to this -// interface, which is just the subset of it that we use. -type Dialer interface { - DialContext(ctx context.Context, net, addr string) (net.Conn, error) -} - // NewClient creates a new RegionClient. func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.Duration, effectiveUser string, readTimeout time.Duration, codec compression.Codec, - dialer Dialer) hrpc.RegionClient { + dialer func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient { c := &client{ addr: addr, ctype: ctype, @@ -48,7 +42,7 @@ func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time. c.dialer = dialer } else { var d net.Dialer - c.dialer = &d + c.dialer = d.DialContext } return c @@ -57,7 +51,7 @@ func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time. func (c *client) Dial(ctx context.Context) error { c.dialOnce.Do(func() { var err error - c.conn, err = c.dialer.DialContext(ctx, "tcp", c.addr) + c.conn, err = c.dialer(ctx, "tcp", c.addr) if err != nil { c.fail(fmt.Errorf("failed to dial RegionServer: %s", err)) return diff --git a/rpc.go b/rpc.go index 0654fe30..e9166a38 100644 --- a/rpc.go +++ b/rpc.go @@ -16,7 +16,6 @@ import ( "time" log "github.com/sirupsen/logrus" - "github.com/tsuna/gohbase/hrpc" "github.com/tsuna/gohbase/internal/observability" "github.com/tsuna/gohbase/region" diff --git a/rpc_test.go b/rpc_test.go index e3eb9692..b7bbe45c 100644 --- a/rpc_test.go +++ b/rpc_test.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "math/rand" + "net" "reflect" "strconv" "strings" @@ -20,7 +21,6 @@ import ( "github.com/golang/mock/gomock" "github.com/prometheus/client_golang/prometheus" - "github.com/tsuna/gohbase/compression" "github.com/tsuna/gohbase/hrpc" "github.com/tsuna/gohbase/pb" @@ -302,7 +302,7 @@ func TestEstablishRegionDialFail(t *testing.T) { newRegionClientFnCallCount := 0 c.newRegionClientFn = func(_ string, _ region.ClientType, _ int, _ time.Duration, - _ string, _ time.Duration, _ compression.Codec, _ region.Dialer) hrpc.RegionClient { + _ string, _ time.Duration, _ compression.Codec, _ func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient { var rc hrpc.RegionClient if newRegionClientFnCallCount == 0 { rc = rcFailDial diff --git a/zk/client.go b/zk/client.go index 75abfd5f..2687199b 100644 --- a/zk/client.go +++ b/zk/client.go @@ -7,6 +7,7 @@ package zk import ( + "context" "encoding/binary" "fmt" "net" @@ -58,11 +59,11 @@ type Client interface { type client struct { zks []string sessionTimeout time.Duration - dialer zk.Dialer + dialer func(ctx context.Context, network, addr string) (net.Conn, error) } // NewClient establishes connection to zookeeper and returns the client -func NewClient(zkquorum string, st time.Duration, dialer zk.Dialer) Client { +func NewClient(zkquorum string, st time.Duration, dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Client { return &client{ zks: strings.Split(zkquorum, ","), sessionTimeout: st, @@ -75,7 +76,7 @@ func (c *client) LocateResource(resource ResourceName) (string, error) { var conn *zk.Conn var err error if c.dialer != nil { - conn, _, err = zk.Connect(c.zks, c.sessionTimeout, zk.WithDialer(c.dialer)) + conn, _, err = zk.Connect(c.zks, c.sessionTimeout, zk.WithDialer(makeZKDialer(c.dialer))) } else { conn, _, err = zk.Connect(c.zks, c.sessionTimeout) } @@ -124,3 +125,11 @@ func (c *client) LocateResource(resource ResourceName) (string, error) { } return net.JoinHostPort(*server.HostName, fmt.Sprint(*server.Port)), nil } + +func makeZKDialer(ctxDialer func(ctx context.Context, network, addr string) (net.Conn, error)) zk.Dialer { + return func(network, addr string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return ctxDialer(ctx, network, addr) + } +}