Skip to content

Commit

Permalink
Update PR per comments
Browse files Browse the repository at this point in the history
Change dialers to be unnamed function types
  • Loading branch information
marcinromaszewicz committed Apr 25, 2024
1 parent d2d4f4d commit a72a2de
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 25 deletions.
15 changes: 7 additions & 8 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,11 @@ import (
"encoding/binary"
"encoding/json"
"fmt"
"net"
"sync"
"time"

gzk "github.com/go-zookeeper/zk"
log "github.com/sirupsen/logrus"

"github.com/tsuna/gohbase/compression"
"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
Expand Down Expand Up @@ -99,15 +98,15 @@ type client struct {
closeOnce sync.Once

newRegionClientFn func(string, region.ClientType, int, time.Duration,
string, time.Duration, compression.Codec, region.Dialer) hrpc.RegionClient
string, time.Duration, compression.Codec, func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient

compressionCodec compression.Codec

// zkDialer is passed through to Zk Connect() to configure custom connection settings
zkDialer gzk.Dialer
// zkDialer is used in the zkClient to connect to the quorum
zkDialer func(ctx context.Context, network, addr string) (net.Conn, error)
// regionDialer is passed into the region client to connect to hbase in a custom way,
// such as SOCKS proxy.
regionDialer region.Dialer
regionDialer func(ctx context.Context, network, addr string) (net.Conn, error)
}

// NewClient creates a new HBase client.
Expand Down Expand Up @@ -279,15 +278,15 @@ func CompressionCodec(codec string) Option {
// ZooKeeperDialer will return an option to pass the given dialer function
// into the ZooKeeper client Connect() call, which allows for customizing
// network connections.
func ZooKeeperDialer(dialer gzk.Dialer) Option {
func ZooKeeperDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(c *client) {
c.zkDialer = dialer
}
}

// RegionDialer will return an option that uses the specified Dialer for
// connecting to region servers. This allows for connecting through proxies.
func RegionDialer(dialer region.Dialer) Option {
func RegionDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Option {
return func(c *client) {
c.regionDialer = dialer
}
Expand Down
3 changes: 2 additions & 1 deletion mockrc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"bytes"
"context"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -177,7 +178,7 @@ func init() {

func newMockRegionClient(addr string, ctype region.ClientType, queueSize int,
flushInterval time.Duration, effectiveUser string,
readTimeout time.Duration, codec compression.Codec, dialer region.Dialer) hrpc.RegionClient {
readTimeout time.Duration, codec compression.Codec, dialer func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient {
m.Lock()
clients[addr]++
m.Unlock()
Expand Down
2 changes: 1 addition & 1 deletion region/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ type client struct {
compressor *compressor

// dialer is used to connect to region servers in non-standard ways
dialer Dialer
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
}

// QueueRPC will add an rpc call to the queue for processing by the writer goroutine
Expand Down
12 changes: 3 additions & 9 deletions region/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,10 @@ import (
"github.com/tsuna/gohbase/hrpc"
)

// Dialer is used to connect to region servers. net.Dialer conforms to this
// interface, which is just the subset of it that we use.
type Dialer interface {
DialContext(ctx context.Context, net, addr string) (net.Conn, error)
}

// NewClient creates a new RegionClient.
func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.Duration,
effectiveUser string, readTimeout time.Duration, codec compression.Codec,
dialer Dialer) hrpc.RegionClient {
dialer func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient {
c := &client{
addr: addr,
ctype: ctype,
Expand All @@ -48,7 +42,7 @@ func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.
c.dialer = dialer
} else {
var d net.Dialer
c.dialer = &d
c.dialer = d.DialContext
}

return c
Expand All @@ -57,7 +51,7 @@ func NewClient(addr string, ctype ClientType, queueSize int, flushInterval time.
func (c *client) Dial(ctx context.Context) error {
c.dialOnce.Do(func() {
var err error
c.conn, err = c.dialer.DialContext(ctx, "tcp", c.addr)
c.conn, err = c.dialer(ctx, "tcp", c.addr)
if err != nil {
c.fail(fmt.Errorf("failed to dial RegionServer: %s", err))
return
Expand Down
1 change: 0 additions & 1 deletion rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"time"

log "github.com/sirupsen/logrus"

"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/internal/observability"
"github.com/tsuna/gohbase/region"
Expand Down
4 changes: 2 additions & 2 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"math/rand"
"net"
"reflect"
"strconv"
"strings"
Expand All @@ -20,7 +21,6 @@ import (

"github.com/golang/mock/gomock"
"github.com/prometheus/client_golang/prometheus"

"github.com/tsuna/gohbase/compression"
"github.com/tsuna/gohbase/hrpc"
"github.com/tsuna/gohbase/pb"
Expand Down Expand Up @@ -302,7 +302,7 @@ func TestEstablishRegionDialFail(t *testing.T) {

newRegionClientFnCallCount := 0
c.newRegionClientFn = func(_ string, _ region.ClientType, _ int, _ time.Duration,
_ string, _ time.Duration, _ compression.Codec, _ region.Dialer) hrpc.RegionClient {
_ string, _ time.Duration, _ compression.Codec, _ func(ctx context.Context, network, addr string) (net.Conn, error)) hrpc.RegionClient {
var rc hrpc.RegionClient
if newRegionClientFnCallCount == 0 {
rc = rcFailDial
Expand Down
15 changes: 12 additions & 3 deletions zk/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package zk

import (
"context"
"encoding/binary"
"fmt"
"net"
Expand Down Expand Up @@ -58,11 +59,11 @@ type Client interface {
type client struct {
zks []string
sessionTimeout time.Duration
dialer zk.Dialer
dialer func(ctx context.Context, network, addr string) (net.Conn, error)
}

// NewClient establishes connection to zookeeper and returns the client
func NewClient(zkquorum string, st time.Duration, dialer zk.Dialer) Client {
func NewClient(zkquorum string, st time.Duration, dialer func(ctx context.Context, network, addr string) (net.Conn, error)) Client {
return &client{
zks: strings.Split(zkquorum, ","),
sessionTimeout: st,
Expand All @@ -75,7 +76,7 @@ func (c *client) LocateResource(resource ResourceName) (string, error) {
var conn *zk.Conn
var err error
if c.dialer != nil {
conn, _, err = zk.Connect(c.zks, c.sessionTimeout, zk.WithDialer(c.dialer))
conn, _, err = zk.Connect(c.zks, c.sessionTimeout, zk.WithDialer(makeZKDialer(c.dialer)))
} else {
conn, _, err = zk.Connect(c.zks, c.sessionTimeout)
}
Expand Down Expand Up @@ -124,3 +125,11 @@ func (c *client) LocateResource(resource ResourceName) (string, error) {
}
return net.JoinHostPort(*server.HostName, fmt.Sprint(*server.Port)), nil
}

func makeZKDialer(ctxDialer func(ctx context.Context, network, addr string) (net.Conn, error)) zk.Dialer {
return func(network, addr string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return ctxDialer(ctx, network, addr)
}
}

0 comments on commit a72a2de

Please sign in to comment.