Skip to content

Commit

Permalink
Tracks the active usage counts of each connection.
Browse files Browse the repository at this point in the history
Only GC those with 0 activeUsage.
Reduces the purgeTime to 10 minutes.

PiperOrigin-RevId: 657590714
Change-Id: I32ed9583f0f642c49a94a73607a296f7971e74e9
  • Loading branch information
Sax Authors authored and copybara-github committed Jul 30, 2024
1 parent 76a6ca6 commit af3f4e8
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 50 deletions.
111 changes: 65 additions & 46 deletions saxml/client/go/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"

log "github.com/golang/glog"
Expand All @@ -32,23 +33,35 @@ import (
const (
sleepTime = 5 * time.Second
fastPurgeTime = 10 * time.Second
purgeTime = 60 * time.Minute
purgeTime = 10 * time.Minute
dialTimeout = 2 * time.Second
)

type conn struct {
// Conn represents a rpc connection to a target.
type Conn struct {
client *grpc.ClientConn
lastAccTime time.Time
activeUsage atomic.Int32
}

// Client returns the underlying grpc connection.
func (c *Conn) Client() *grpc.ClientConn {
return c.client
}

// Release decrements the usage count by 1.
func (c *Conn) Release() {
c.activeUsage.Add(-1)
}

type connTable struct {
mu sync.RWMutex
// Mapping between modelet address to modelet connection and last access time.
table map[string]*conn
table map[string]*Conn
}

func newConnTable() *connTable {
c := &connTable{table: make(map[string]*conn)}
c := &connTable{table: make(map[string]*Conn)}

go func() {
// A loop that clears the connections based on last access time.
Expand All @@ -59,22 +72,26 @@ func newConnTable() *connTable {
c.mu.Lock()
log.V(2).Infof("clearing connTable with %d connections: %v %v\n", len(c.table), purgeTime, fastPurgeTime)
for addr, conn := range c.table {
shouldClose := false
if conn.client.GetState() == connectivity.Ready {
// If the grpc connection is ready but has not been used for quite a while, close it.
if conn.lastAccTime.Before(now.Add(-purgeTime)) {
if conn.activeUsage.Load() != 0 {
log.V(2).Infof("active connection (%d) %s", conn.activeUsage.Load(), addr)
} else {
shouldClose := false
if conn.client.GetState() == connectivity.Ready {
// If the grpc connection is ready but has not been used for quite a while, close it.
if conn.lastAccTime.Before(now.Add(-purgeTime)) {
shouldClose = true
log.V(3).Infof("conneTable removed idle connection to addr %s\n", addr)
}
} else if conn.lastAccTime.Before(now.Add(-fastPurgeTime)) {
// If the grpc connection is _not_ ready and has not been used recently, close it.
shouldClose = true
log.V(3).Infof("conneTable removed idle connection to addr %s\n", addr)
log.V(3).Infof("conneTable removed addr %s\n", addr)
}
if shouldClose {
log.Infof("connTable close %s", addr)
conn.client.Close()
delete(c.table, addr) // It's safe to delete and traverse.
}
} else if conn.lastAccTime.Before(now.Add(-fastPurgeTime)) {
// If the grpc connection is _not_ ready and has not been used recently, close it.
shouldClose = true
log.V(3).Infof("conneTable removed addr %s\n", addr)
}
if shouldClose {
log.Infof("connTable close %s", addr)
conn.client.Close()
delete(c.table, addr) // It's safe to delete and traverse.
}
}
log.V(2).Infof("after clearing connTable with %v/%v, there are %d connections\n", purgeTime, fastPurgeTime, len(c.table))
Expand All @@ -87,30 +104,29 @@ func newConnTable() *connTable {
}

// checkAndGet checks the existence of connecton for an addrress and returns connection.
// The returned boolean indicates if connection is found.
func (t *connTable) checkAndGet(addr string) (*grpc.ClientConn, bool) {
func (t *connTable) checkAndGet(addr string) *Conn {
t.mu.Lock()
defer t.mu.Unlock()
connection, found := t.table[addr]
if found && connection.client != nil {
if found {
connection.lastAccTime = time.Now()
return connection.client, true
connection.activeUsage.Add(1)
return connection
}
return nil, false
return nil
}

func (t *connTable) getOrCreate(ctx context.Context, addr string) (*grpc.ClientConn, error) {
existingClient, found := t.checkAndGet(addr)
if found && existingClient != nil {
return existingClient, nil
func (t *connTable) getOrCreate(ctx context.Context, addr string) (*Conn, error) {
existingConn := t.checkAndGet(addr)
if existingConn != nil {
return existingConn, nil
}

// Couldn't find connection. Create a new one.
var newClient *grpc.ClientConn
var err error
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
newClient, err = env.Get().DialContext(ctx, addr)
newClient, err := env.Get().DialContext(ctx, addr)
if err != nil || newClient == nil {
log.V(3).Infof("getOrCreate create connection for %s failed due to %v\n", addr, err)
if errors.IsDeadlineExceeded(err) {
Expand All @@ -121,23 +137,23 @@ func (t *connTable) getOrCreate(ctx context.Context, addr string) (*grpc.ClientC

t.mu.Lock()
defer t.mu.Unlock()
existingConn, found := t.table[addr] // Re-check since Dial was not locked.
if found && existingConn.client != nil {
conn, found := t.table[addr] // Re-check since Dial was not locked.
if found {
newClient.Close()
existingConn.lastAccTime = time.Now()
return existingConn.client, nil
} else {
conn = &Conn{client: newClient}
t.table[addr] = conn
}

newConn := &conn{client: newClient, lastAccTime: time.Now()}
t.table[addr] = newConn
return newClient, nil
conn.lastAccTime = time.Now()
conn.activeUsage.Add(1)
return conn, nil
}

var globalConnTable *connTable = newConnTable()

// Factory manages connections to a given model.
type Factory interface {
GetOrCreate(ctx context.Context) (*grpc.ClientConn, error)
GetOrCreate(ctx context.Context) (*Conn, error)
}

// SaxConnectionFactory resolves backends via SAX admin server and connects to them in a round-robin fashion.
Expand All @@ -146,7 +162,7 @@ type SaxConnectionFactory struct {
}

// GetOrCreate selects a server and returns a connection to it.
func (f SaxConnectionFactory) GetOrCreate(ctx context.Context) (conn *grpc.ClientConn, err error) {
func (f SaxConnectionFactory) GetOrCreate(ctx context.Context) (conn *Conn, err error) {
addr, err := f.Location.Pick(ctx)
if err == nil {
conn, err = globalConnTable.getOrCreate(ctx, addr)
Expand All @@ -161,16 +177,19 @@ type DirectConnectionFactory struct {
connection *grpc.ClientConn
}

// GetOrCreate returns a connection and address of the model server.
func (f *DirectConnectionFactory) GetOrCreate(ctx context.Context) (conn *grpc.ClientConn, err error) {
// GetOrCreate returns a connection to the address of a model server.
func (f *DirectConnectionFactory) GetOrCreate(ctx context.Context) (conn *Conn, err error) {
// WithDefaultServiceConfig is required for MBNS. It will be ignored if the server is backed by GRPC.
f.mu.Lock()
defer f.mu.Unlock()
conn = f.connection
if conn == nil {
conn, err = env.Get().DialContext(ctx, f.Address,
if f.connection == nil {
connection, err := env.Get().DialContext(ctx, f.Address,
grpc.WithDefaultServiceConfig(`{"loadBalancingConfig": [{"round_robin":{}}]}`))
f.connection = conn
if err == nil {
f.connection = connection
} else {
return nil, err
}
}
return conn, err
return &Conn{client: f.connection}, nil
}
9 changes: 6 additions & 3 deletions saxml/client/go/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ func TestConnection(t *testing.T) {
if err != nil {
t.Fatalf("Creating connection for address %s failed with %v\n", addresses[j], err)
}
defer conn.Release()
req := &pb.ScoreRequest{
ModelKey: "m1",
Suffix: []string{"abc"},
Prefix: "xyz",
}
modelServer := pbgrpc.NewLMServiceClient(conn)
modelServer := pbgrpc.NewLMServiceClient(conn.Client())
res, err := modelServer.Score(context.Background(), req)
if err != nil {
t.Fatalf("Unable to Score() against address %s due to %v\n", addresses[j], err)
Expand Down Expand Up @@ -108,7 +109,8 @@ func TestBrokenConnection(t *testing.T) {
if err != nil {
t.Fatalf("Creating connection for address %s failed with %v\n", addr, err)
}
defer conn.Close()
defer conn.Client().Close()
defer conn.Release()

// Shut down the model server.
close(closer)
Expand All @@ -119,9 +121,10 @@ func TestBrokenConnection(t *testing.T) {
if err != nil {
t.Fatalf("Getting connection for address %s failed with %v\n", addr, err)
}
defer conn.Release()

// Attempting to use the connection should return an Unavailable error.
_, err = pbgrpc.NewLMServiceClient(conn).Score(ctx, &pb.ScoreRequest{})
_, err = pbgrpc.NewLMServiceClient(conn.Client()).Score(ctx, &pb.ScoreRequest{})
want := codes.Unavailable
if got := saxerrors.Code(err); got != want {
t.Errorf("Expect error code %v, got %v", want, got)
Expand Down
3 changes: 2 additions & 1 deletion saxml/client/go/sax.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ func (m *Model) run(ctx context.Context, methodName string, callMethod func(conn
makeQuery := func() error {
modelServerConn, err := m.connectionFactory.GetOrCreate(ctx)
if err == nil {
err = callMethod(modelServerConn)
err = callMethod(modelServerConn.Client())
modelServerConn.Release()
} else if errors.IsNotFound(err) {
// If the model does not exist anymore, no point to retry.
err = backoff.Permanent(err)
Expand Down

0 comments on commit af3f4e8

Please sign in to comment.