diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index bdd8c10b8838c..028adfb30db78 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -2016,6 +2016,14 @@ func withDeniedDBLabels(labels types.Labels) roleOptFn { } } +func withClientIdleTimeout(clientIdleTimeout time.Duration) roleOptFn { + return func(role types.Role) { + opts := role.GetOptions() + opts.ClientIdleTimeout = types.NewDuration(clientIdleTimeout) + role.SetOptions(opts) + } +} + // createUserAndRole creates Teleport user and role with specified names // and allowed database users/names properties. func (c *testContext) createUserAndRole(ctx context.Context, t testing.TB, userName, roleName string, dbUsers, dbNames []string, roleOpts ...roleOptFn) (types.User, types.Role) { diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index a829391896771..8224dea33d38b 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -980,6 +980,14 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro } }() }() + + // Wrap a client connection into monitor that auto-terminates + // idle connection and connection with expired cert. + ctx, clientConn, err = s.cfg.ConnectionMonitor.MonitorConn(cancelCtx, sessionCtx.AuthContext, clientConn) + if err != nil { + return trace.Wrap(err) + } + engine, err := s.dispatch(sessionCtx, rec, clientConn) if err != nil { return trace.Wrap(err) @@ -995,13 +1003,6 @@ func (s *Server) handleConnection(ctx context.Context, clientConn net.Conn) erro } }() - // Wrap a client connection into monitor that auto-terminates - // idle connection and connection with expired cert. - ctx, clientConn, err = s.cfg.ConnectionMonitor.MonitorConn(cancelCtx, sessionCtx.AuthContext, clientConn) - if err != nil { - return trace.Wrap(err) - } - // TODO(jakule): LoginIP should be required starting from 10.0. clientIP := sessionCtx.Identity.LoginIP if clientIP != "" { diff --git a/lib/srv/db/server_test.go b/lib/srv/db/server_test.go index 79a93c04dde8d..03bfbe884a00d 100644 --- a/lib/srv/db/server_test.go +++ b/lib/srv/db/server_test.go @@ -32,6 +32,7 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/services" ) @@ -184,6 +185,57 @@ func TestDatabaseServerLimiting(t *testing.T) { }) } +func TestDatabaseServerAutoDisconnect(t *testing.T) { + const ( + user = "bob" + role = "admin" + dbName = "postgres" + dbUser = user + ) + + ctx := context.Background() + allowDbUsers := []string{types.Wildcard} + allowDbNames := []string{types.Wildcard} + + testCtx := setupTestContext(ctx, t, withSelfHostedPostgres("postgres")) + + go testCtx.startHandlingConnections() + t.Cleanup(func() { + require.NoError(t, testCtx.Close()) + }) + + const clientIdleTimeout = time.Second * 30 + + // create user/role with client idle timeout + testCtx.createUserAndRole(ctx, t, user, role, allowDbUsers, allowDbNames, withClientIdleTimeout(clientIdleTimeout)) + + // connect + pgConn, err := testCtx.postgresClient(ctx, user, "postgres", dbUser, dbName) + require.NoError(t, err) + + // immediate query should work + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.NoError(t, err) + + // advance clock several times, perform query. + // the activity should update the idle activity timer. + for i := 0; i < 10; i++ { + testCtx.clock.Advance(clientIdleTimeout / 2) + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.NoErrorf(t, err, "failed on iteration %v", i+1) + } + + // advance clock by full idle timeout, expect the client to be disconnected automatically. + testCtx.clock.Advance(clientIdleTimeout) + waitForEvent(t, testCtx, events.ClientDisconnectCode) + + // expect failure after timeout. + _, err = pgConn.Exec(ctx, "select 1").ReadAll() + require.Error(t, err) + + require.NoError(t, pgConn.Close(ctx)) +} + func TestHeartbeatEvents(t *testing.T) { ctx := context.Background()