Skip to content

Commit

Permalink
feat: SSO MFA - Moderated Sessions (#48646)
Browse files Browse the repository at this point in the history
* Add sso client redirect url to maintain session rpc.

* Handle sso client redirect URL in maintain session presence RPC.

* Add client implementation of SSO MFA for moderated sessions.

* Cleanup.
  • Loading branch information
Joerger authored Nov 12, 2024
1 parent 4750262 commit 169128d
Show file tree
Hide file tree
Showing 10 changed files with 1,058 additions and 1,002 deletions.
1,945 changes: 999 additions & 946 deletions api/client/proto/authservice.pb.go

Large diffs are not rendered by default.

18 changes: 9 additions & 9 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,8 @@ type CreateAuthenticateChallengeFunc func(ctx context.Context, req *proto.Create
// req may be nil if ceremony.CreateAuthenticateChallenge does not require it, e.g. in
// the moderated session mfa ceremony which uses a custom stream rpc to create challenges.
func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error) {
switch {
case c.CreateAuthenticateChallenge == nil:
if c.CreateAuthenticateChallenge == nil {
return nil, trace.BadParameter("mfa ceremony must have CreateAuthenticateChallenge set in order to begin")
case req == nil:
// req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources
// its own req or uses a different rpc, e.g. moderated sessions.
case req.ChallengeExtensions == nil:
return nil, trace.BadParameter("missing challenge extensions")
case req.ChallengeExtensions.Scope == mfav1.ChallengeScope_CHALLENGE_SCOPE_UNSPECIFIED:
return nil, trace.BadParameter("mfa challenge scope must be specified")
}

// If available, prepare an SSO MFA ceremony and set the client redirect URL in the challenge
Expand All @@ -78,6 +70,14 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
slog.DebugContext(ctx, "Failed to attempt SSO MFA, continuing with other MFA methods", "error", err)
} else {
defer ssoMFACeremony.Close()

// req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources
// its own req or uses a different e.g. login. We should still provide the sso client
// redirect URL in case the custom CreateAuthenticateChallenge handles it.
if req == nil {
req = new(proto.CreateAuthenticateChallengeRequest)
}

req.SSOClientRedirectURL = ssoMFACeremony.GetClientCallbackURL()
promptOpts = append(promptOpts, withSSOMFACeremony(ssoMFACeremony))
}
Expand Down
3 changes: 3 additions & 0 deletions api/proto/teleport/legacy/client/proto/authservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,9 @@ message UpdateSessionTrackerRequest {
message PresenceMFAChallengeRequest {
// SessionID is unique identifier of the session you want to request presence for.
string SessionID = 1 [(gogoproto.jsontag) = "session_id,omitempty"];
// SSOClientRedirectURL should be supplied If the client supports SSO MFA checks.
// If unset, the server will only return non-SSO challenges.
string SSOClientRedirectURL = 2 [(gogoproto.jsontag) = "sso_client_redirect_url,omitempty"];
}

// PresenceMFAChallengeSend is a presence challenge request or response.
Expand Down
5 changes: 1 addition & 4 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2167,11 +2167,8 @@ func (g *GRPCServer) DeleteRole(ctx context.Context, req *authpb.DeleteRoleReque
func doMFAPresenceChallenge(ctx context.Context, actx *grpcContext, stream authpb.AuthService_MaintainSessionPresenceServer, challengeReq *authpb.PresenceMFAChallengeRequest) error {
user := actx.User.GetName()

// TODO(Joerger): Extend SSO MFA support for moderated sessions.
var ssoClientRedirectURL string

chalExt := &mfav1pb.ChallengeExtensions{Scope: mfav1pb.ChallengeScope_CHALLENGE_SCOPE_USER_SESSION}
authChallenge, err := actx.authServer.mfaAuthChallenge(ctx, user, ssoClientRedirectURL, chalExt)
authChallenge, err := actx.authServer.mfaAuthChallenge(ctx, user, challengeReq.SSOClientRedirectURL, chalExt)
if err != nil {
return trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2239,7 +2239,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan
if mode == types.SessionModeratorMode {
beforeStart = func(out io.Writer) {
nc.OnMFA = func() {
RunPresenceTask(presenceCtx, out, clt.AuthClient, session.GetSessionID(), tc.NewMFAPrompt(mfa.WithQuiet()))
RunPresenceTask(presenceCtx, out, clt.AuthClient, session.GetSessionID(), tc.NewMFACeremony())
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions lib/client/kubesession.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (

"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client/terminal"
"github.com/gravitational/teleport/lib/kube/proxy/streamproto"
Expand Down Expand Up @@ -196,7 +195,7 @@ func (s *KubeSession) handleMFA(ctx context.Context, tc *TeleportClient, mode ty
}

go func() {
RunPresenceTask(ctx, stdout, auth, s.meta.GetSessionID(), tc.NewMFAPrompt(mfa.WithQuiet()))
RunPresenceTask(ctx, stdout, auth, s.meta.GetSessionID(), tc.NewMFACeremony())
auth.Close()
clt.Close()
}()
Expand Down
21 changes: 14 additions & 7 deletions lib/client/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func WithPresenceClock(clock clockwork.Clock) PresenceOption {

// RunPresenceTask periodically performs and MFA ceremony to detect that a user is
// still present and attentive.
func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...PresenceOption) error {
func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMaintainer, sessionID string, baseCeremony *mfa.Ceremony, opts ...PresenceOption) error {
fmt.Fprintf(term, "\r\nTeleport > MFA presence enabled\r\n")

o := &presenceOptions{
Expand All @@ -76,12 +76,16 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
return trace.Wrap(err)
}

mfaCeremony := &mfa.Ceremony{
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
presenceCeremony := &mfa.Ceremony{
SSOMFACeremonyConstructor: baseCeremony.SSOMFACeremonyConstructor,
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
// Replace normal output with terminal messages specific to moderated sessions.
opts = append(opts, mfa.WithQuiet())

fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n")

mfaResp, err := mfaPrompt.Run(ctx, chal)
mfaResp, err := baseCeremony.PromptConstructor(opts...).Run(ctx, chal)
if err != nil {
fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err)
return nil, trace.Wrap(err)
Expand All @@ -91,10 +95,13 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
return mfaResp, nil
})
},
CreateAuthenticateChallenge: func(ctx context.Context, _ *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
CreateAuthenticateChallenge: func(ctx context.Context, chalReq *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
req := &proto.PresenceMFAChallengeSend{
Request: &proto.PresenceMFAChallengeSend_ChallengeRequest{
ChallengeRequest: &proto.PresenceMFAChallengeRequest{SessionID: sessionID},
ChallengeRequest: &proto.PresenceMFAChallengeRequest{
SessionID: sessionID,
SSOClientRedirectURL: chalReq.SSOClientRedirectURL,
},
},
}

Expand All @@ -118,7 +125,7 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
for {
select {
case <-ticker.Chan():
mfaResp, err := mfaCeremony.Run(ctx, nil /* req is not needed for MaintainSessionPresence */)
mfaResp, err := presenceCeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{})
if err != nil {
return trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ type ProxySettingsGetter interface {

// PresenceChecker is a function that executes an mfa prompt to enforce
// that a user is present.
type PresenceChecker = func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...client.PresenceOption) error
type PresenceChecker = func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaCeremony *mfa.Ceremony, opts ...client.PresenceOption) error

// Config represents web handler configuration parameters
type Config struct {
Expand Down
4 changes: 2 additions & 2 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10004,8 +10004,8 @@ func TestModeratedSessionWithMFA(t *testing.T) {
RPID: RPID,
},
},
presenceChecker: func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...client.PresenceOption) error {
return trace.Wrap(client.RunPresenceTask(ctx, term, maintainer, sessionID, mfaPrompt, client.WithPresenceClock(presenceClock)))
presenceChecker: func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaCeremony *mfa.Ceremony, opts ...client.PresenceOption) error {
return trace.Wrap(client.RunPresenceTask(ctx, term, maintainer, sessionID, mfaCeremony, client.WithPresenceClock(presenceClock)))
},
})

Expand Down
57 changes: 27 additions & 30 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,11 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te
}

func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc) *mfa.Ceremony {
// channelID is used by the front end to differentiate between separate ongoing SSO challenges.
var channelID string

return &mfa.Ceremony{
CreateAuthenticateChallenge: createAuthenticateChallenge,
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return newMFAPrompt(stream, channelID)
},
SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) {
id, err := uuid.NewRandom()
if err != nil {
Expand All @@ -635,36 +634,33 @@ func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.C
ClientCallbackURL: u.String(),
}, nil
},
}
}

// create a new MFA prompt. When provided, ssoChannelID is used by the front end to differentiate
// between separate ongoing SSO challenges.
func newMFAPrompt(stream *terminal.WSStream, ssoChannelID string) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
// Convert from proto to JSON types.
var challenge client.MFAAuthenticateChallenge
if chal.WebauthnChallenge != nil {
challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge)
}
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
// Convert from proto to JSON types.
var challenge client.MFAAuthenticateChallenge
if chal.WebauthnChallenge != nil {
challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge)
}

if chal.SSOChallenge != nil {
challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge)
challenge.SSOChallenge.ChannelID = ssoChannelID
}
if chal.SSOChallenge != nil {
challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge)
challenge.SSOChallenge.ChannelID = channelID
}

if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil {
return nil, trace.AccessDenied("only WebAuthn and SSO MFA methods are supported on the web terminal, please register a supported mfa method to connect to this server")
}
if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil {
return nil, trace.AccessDenied("only WebAuthn and SSO MFA methods are supported on the web terminal, please register a supported mfa method to connect to this server")
}

var codec protobufMFACodec
if err := stream.WriteChallenge(&challenge, codec); err != nil {
return nil, trace.Wrap(err)
}
var codec protobufMFACodec
if err := stream.WriteChallenge(&challenge, codec); err != nil {
return nil, trace.Wrap(err)
}

resp, err := stream.ReadChallengeResponse(codec)
return resp, trace.Wrap(err)
})
resp, err := stream.ReadChallengeResponse(codec)
return resp, trace.Wrap(err)
})
},
}
}

type connectWithMFAFn = func(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error)
Expand Down Expand Up @@ -822,7 +818,8 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
if t.participantMode == types.SessionModeratorMode {
beforeStart = func(out io.Writer) {
nc.OnMFA = func() {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream, "" /*ssoChannelID*/)); err != nil {
baseCeremony := newMFACeremony(t.stream.WSStream, nil)
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), baseCeremony); err != nil {
t.log.WithError(err).Warn("Unable to stream terminal - failure performing presence checks")
return
}
Expand Down

0 comments on commit 169128d

Please sign in to comment.