Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] wait after spanner test client is disconnected #49722

Merged
merged 2 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions lib/srv/db/access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"testing"
"time"

gspanner "cloud.google.com/go/spanner"
"github.com/ClickHouse/ch-go"
cqlclient "github.com/datastax/go-cassandra-native-protocol/client"
elastic "github.com/elastic/go-elasticsearch/v8"
Expand Down Expand Up @@ -2133,7 +2132,7 @@ func (c *testContext) dynamodbClient(ctx context.Context, teleportUser, dbServic
return db, proxy, nil
}

func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*gspanner.Client, *alpnproxy.LocalProxy, error) {
func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService, dbUser, dbName string) (*spanner.SpannerTestClient, *alpnproxy.LocalProxy, error) {
route := tlsca.RouteToDatabase{
ServiceName: dbService,
Protocol: defaults.ProtocolSpanner,
Expand All @@ -2146,7 +2145,7 @@ func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService
return nil, nil, trace.Wrap(err)
}

db, err := spanner.MakeTestClient(ctx, common.TestClientConfig{
clt, err := spanner.MakeTestClient(ctx, common.TestClientConfig{
AuthClient: c.authClient,
AuthServer: c.authServer,
Address: proxy.GetAddr(),
Expand All @@ -2158,7 +2157,7 @@ func (c *testContext) spannerClient(ctx context.Context, teleportUser, dbService
return nil, nil, trace.Wrap(err)
}

return db, proxy, nil
return clt, proxy, nil
}

type roleOptFn func(types.Role)
Expand Down
46 changes: 37 additions & 9 deletions lib/srv/db/spanner/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/sirupsen/logrus"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
Expand All @@ -47,11 +48,35 @@ import (
"github.com/gravitational/teleport/lib/tlsca"
)

func MakeTestClient(ctx context.Context, config common.TestClientConfig) (*spanner.Client, error) {
// SpannerTestClient wraps a [spanner.Client] and provides direct access to the
// underlying [grpc.ClientConn] of the client.
type SpannerTestClient struct {
ClientConn *grpc.ClientConn
*spanner.Client
}

// WaitForConnectionState waits until the spanner client's underlying gRPC
// connection transitions into the given state or the context expires.
func (c *SpannerTestClient) WaitForConnectionState(ctx context.Context, wantState connectivity.State) error {
for {
s := c.ClientConn.GetState()
if s == wantState {
return nil
}
if s == connectivity.Shutdown {
return trace.Errorf("spanner test client connection has shutdown")
}
if !c.ClientConn.WaitForStateChange(ctx, s) {
return ctx.Err()
}
}
}

func MakeTestClient(ctx context.Context, config common.TestClientConfig) (*SpannerTestClient, error) {
return makeTestClient(ctx, config, false)
}

func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS bool) (*spanner.Client, error) {
func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS bool) (*SpannerTestClient, error) {
databaseID, err := getDatabaseID(ctx, config.RouteToDatabase, config.AuthServer)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -68,13 +93,13 @@ func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS
transportOpt = grpc.WithTransportCredentials(insecure.NewCredentials())
}

cc, err := grpc.NewClient(config.Address, transportOpt)
if err != nil {
return nil, trace.Wrap(err)
}

opts := []option.ClientOption{
// dial with custom transport security
option.WithGRPCDialOption(transportOpt),
// create 1 connection
option.WithGRPCConnectionPool(1),
// connect to the Teleport endpoint
option.WithEndpoint(config.Address),
option.WithGRPCConn(cc),
// client should not bring any GCP credentials
option.WithoutAuthentication(),
}
Expand All @@ -86,7 +111,10 @@ func makeTestClient(ctx context.Context, config common.TestClientConfig, useTLS
if err != nil {
return nil, trace.Wrap(err)
}
return clt, nil
return &SpannerTestClient{
ClientConn: cc,
Client: clt,
}, nil
}

func getDatabaseID(ctx context.Context, route tlsca.RouteToDatabase, getter services.DatabaseServersGetter) (string, error) {
Expand Down
14 changes: 12 additions & 2 deletions lib/srv/db/spanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (

gspanner "cloud.google.com/go/spanner"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/connectivity"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
Expand Down Expand Up @@ -234,7 +235,15 @@ func TestAuditSpanner(t *testing.T) {
_ = localProxy.Close()
})

require.NoError(t, err)
require.NoError(t, clt.WaitForConnectionState(ctx, connectivity.Ready))
reconnectingCh := make(chan bool)
go func() {
// we should observe the connection leave the "ready" state after
// it gets an access denied error.
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
reconnectingCh <- clt.ClientConn.WaitForStateChange(ctx, connectivity.Ready)
}()

row, err := pingSpanner(ctx, clt, 42)
require.Error(t, err)
Expand All @@ -246,6 +255,7 @@ func TestAuditSpanner(t *testing.T) {
require.True(t, ok)
require.Equal(t, "googlesql", dbStart1.DatabaseName)

require.True(t, <-reconnectingCh, "timed out waiting for the spanner client to reconnect")
row, err = pingSpanner(ctx, clt, 42)
require.Error(t, err)
require.ErrorContains(t, err, "access to db denied")
Expand Down Expand Up @@ -308,7 +318,7 @@ func TestAuditSpanner(t *testing.T) {
})
}

func pingSpanner(ctx context.Context, clt *gspanner.Client, want int64) (*gspanner.Row, error) {
func pingSpanner(ctx context.Context, clt *spanner.SpannerTestClient, want int64) (*gspanner.Row, error) {
query := gspanner.NewStatement(fmt.Sprintf("SELECT %d", want))
rowIter := clt.Single().Query(ctx, query)
defer rowIter.Stop()
Expand Down
Loading