diff --git a/lib/srv/usermgmt.go b/lib/srv/usermgmt.go index c73e1db41390e..9415c7148096c 100644 --- a/lib/srv/usermgmt.go +++ b/lib/srv/usermgmt.go @@ -485,6 +485,8 @@ func (u *HostUserManagement) UpsertUser(name string, ui services.HostUsersInfo) return closer, nil } +const userLeaseDuration = time.Second * 20 + func (u *HostUserManagement) doWithUserLock(f func(types.SemaphoreLease) error) error { lock, err := services.AcquireSemaphoreWithRetry(u.ctx, services.AcquireSemaphoreWithRetryConfig{ @@ -493,7 +495,7 @@ func (u *HostUserManagement) doWithUserLock(f func(types.SemaphoreLease) error) SemaphoreKind: types.SemaphoreKindHostUserModification, SemaphoreName: "host_user_modification", MaxLeases: 1, - Expires: time.Now().Add(time.Second * 20), + Expires: time.Now().Add(userLeaseDuration), }, Retry: retryutils.LinearConfig{ Step: time.Second * 5, @@ -556,26 +558,35 @@ func (u *HostUserManagement) DeleteAllUsers() error { return trace.Wrap(err) } var errs []error - for _, name := range users { - lt, err := u.storage.GetHostUserInteractionTime(u.ctx, name) - if err != nil { - u.log.DebugContext(u.ctx, "Failed to find user login time", "host_username", name, "error", err) - continue - } - u.doWithUserLock(func(l types.SemaphoreLease) error { + u.doWithUserLock(func(l types.SemaphoreLease) error { + for _, name := range users { + if time.Until(l.Expires) < userLeaseDuration/2 { + l.Expires = time.Now().Add(userLeaseDuration / 2) + if err := u.storage.KeepAliveSemaphoreLease(u.ctx, l); err != nil { + u.log.DebugContext(u.ctx, "Failed to keep alive host user lease", "error", err) + } + } + + lt, err := u.storage.GetHostUserInteractionTime(u.ctx, name) + if err != nil { + u.log.DebugContext(u.ctx, "Failed to find user login time", "host_username", name, "error", err) + continue + } + if time.Since(lt) < u.userGrace { // small grace period in order to avoid deleting users // in-between them starting their SSH session and // entering the shell - return nil + continue } - errs = append(errs, u.DeleteUser(name, teleportGroup.Gid)) - l.Expires = time.Now().Add(time.Second * 10) - u.storage.KeepAliveSemaphoreLease(u.ctx, l) - return nil - }) - } + if err := u.DeleteUser(name, teleportGroup.Gid); err != nil { + errs = append(errs, err) + } + } + + return nil + }) return trace.NewAggregate(errs...) }