Skip to content

Commit

Permalink
Add client implementation of SSO MFA for moderated sessions.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Nov 8, 2024
1 parent a2975c6 commit 94ce7c6
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 46 deletions.
3 changes: 0 additions & 3 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
switch {
case c.CreateAuthenticateChallenge == nil:
return nil, trace.BadParameter("mfa ceremony must have CreateAuthenticateChallenge set in order to begin")
case req == nil:
// req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources
// its own req or uses a different rpc, e.g. moderated sessions.
case req.ChallengeExtensions == nil:
return nil, trace.BadParameter("missing challenge extensions")
case req.ChallengeExtensions.Scope == mfav1.ChallengeScope_CHALLENGE_SCOPE_UNSPECIFIED:
Expand Down
2 changes: 1 addition & 1 deletion lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2239,7 +2239,7 @@ func (tc *TeleportClient) Join(ctx context.Context, mode types.SessionParticipan
if mode == types.SessionModeratorMode {
beforeStart = func(out io.Writer) {
nc.OnMFA = func() {
RunPresenceTask(presenceCtx, out, clt.AuthClient, session.GetSessionID(), tc.NewMFAPrompt(mfa.WithQuiet()))
RunPresenceTask(presenceCtx, out, clt.AuthClient, session.GetSessionID(), tc.NewMFACeremony())
}
}
}
Expand Down
3 changes: 1 addition & 2 deletions lib/client/kubesession.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (

"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client/terminal"
"github.com/gravitational/teleport/lib/kube/proxy/streamproto"
Expand Down Expand Up @@ -196,7 +195,7 @@ func (s *KubeSession) handleMFA(ctx context.Context, tc *TeleportClient, mode ty
}

go func() {
RunPresenceTask(ctx, stdout, auth, s.meta.GetSessionID(), tc.NewMFAPrompt(mfa.WithQuiet()))
RunPresenceTask(ctx, stdout, auth, s.meta.GetSessionID(), tc.NewMFACeremony())
auth.Close()
clt.Close()
}()
Expand Down
28 changes: 21 additions & 7 deletions lib/client/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/api/client/proto"
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
"github.com/gravitational/teleport/api/mfa"
)

Expand Down Expand Up @@ -56,7 +57,7 @@ func WithPresenceClock(clock clockwork.Clock) PresenceOption {

// RunPresenceTask periodically performs and MFA ceremony to detect that a user is
// still present and attentive.
func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...PresenceOption) error {
func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMaintainer, sessionID string, baseCeremony *mfa.Ceremony, opts ...PresenceOption) error {
fmt.Fprintf(term, "\r\nTeleport > MFA presence enabled\r\n")

o := &presenceOptions{
Expand All @@ -76,12 +77,16 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
return trace.Wrap(err)
}

mfaCeremony := &mfa.Ceremony{
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
presenceCeremony := &mfa.Ceremony{
SSOMFACeremonyConstructor: baseCeremony.SSOMFACeremonyConstructor,
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
// Replace normal output with terminal messages specific to moderated sessions.
opts = append(opts, mfa.WithQuiet())

fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n")

mfaResp, err := mfaPrompt.Run(ctx, chal)
mfaResp, err := baseCeremony.PromptConstructor(opts...).Run(ctx, chal)
if err != nil {
fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err)
return nil, trace.Wrap(err)
Expand All @@ -91,10 +96,13 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
return mfaResp, nil
})
},
CreateAuthenticateChallenge: func(ctx context.Context, _ *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
CreateAuthenticateChallenge: func(ctx context.Context, chalReq *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
req := &proto.PresenceMFAChallengeSend{
Request: &proto.PresenceMFAChallengeSend_ChallengeRequest{
ChallengeRequest: &proto.PresenceMFAChallengeRequest{SessionID: sessionID},
ChallengeRequest: &proto.PresenceMFAChallengeRequest{
SessionID: sessionID,
SSOClientRedirectURL: chalReq.SSOClientRedirectURL,
},
},
}

Expand All @@ -118,7 +126,13 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
for {
select {
case <-ticker.Chan():
mfaResp, err := mfaCeremony.Run(ctx, nil /* req is not needed for MaintainSessionPresence */)
mfaResp, err := presenceCeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
// With the custom CreateAuthenticateChallenge method above, we don't actually
// need to provide the extensions here, but the ceremony expects it.
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_USER_SESSION,
},
})
if err != nil {
return trace.Wrap(err)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ type ProxySettingsGetter interface {

// PresenceChecker is a function that executes an mfa prompt to enforce
// that a user is present.
type PresenceChecker = func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...client.PresenceOption) error
type PresenceChecker = func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaCeremony *mfa.Ceremony, opts ...client.PresenceOption) error

// Config represents web handler configuration parameters
type Config struct {
Expand Down
4 changes: 2 additions & 2 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10004,8 +10004,8 @@ func TestModeratedSessionWithMFA(t *testing.T) {
RPID: RPID,
},
},
presenceChecker: func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaPrompt mfa.Prompt, opts ...client.PresenceOption) error {
return trace.Wrap(client.RunPresenceTask(ctx, term, maintainer, sessionID, mfaPrompt, client.WithPresenceClock(presenceClock)))
presenceChecker: func(ctx context.Context, term io.Writer, maintainer client.PresenceMaintainer, sessionID string, mfaCeremony *mfa.Ceremony, opts ...client.PresenceOption) error {
return trace.Wrap(client.RunPresenceTask(ctx, term, maintainer, sessionID, mfaCeremony, client.WithPresenceClock(presenceClock)))
},
})

Expand Down
57 changes: 27 additions & 30 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,11 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te
}

func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc) *mfa.Ceremony {
// channelID is used by the front end to differentiate between separate ongoing SSO challenges.
var channelID string

return &mfa.Ceremony{
CreateAuthenticateChallenge: createAuthenticateChallenge,
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return newMFAPrompt(stream, channelID)
},
SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) {
id, err := uuid.NewRandom()
if err != nil {
Expand All @@ -635,36 +634,33 @@ func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.C
ClientCallbackURL: u.String(),
}, nil
},
}
}

// create a new MFA prompt. When provided, ssoChannelID is used by the front end to differentiate
// between separate ongoing SSO challenges.
func newMFAPrompt(stream *terminal.WSStream, ssoChannelID string) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
// Convert from proto to JSON types.
var challenge client.MFAAuthenticateChallenge
if chal.WebauthnChallenge != nil {
challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge)
}
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
// Convert from proto to JSON types.
var challenge client.MFAAuthenticateChallenge
if chal.WebauthnChallenge != nil {
challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge)
}

if chal.SSOChallenge != nil {
challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge)
challenge.SSOChallenge.ChannelID = ssoChannelID
}
if chal.SSOChallenge != nil {
challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge)
challenge.SSOChallenge.ChannelID = channelID
}

if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil {
return nil, trace.AccessDenied("only WebAuthn and SSO MFA methods are supported on the web terminal, please register a supported mfa method to connect to this server")
}
if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil {
return nil, trace.AccessDenied("only WebAuthn and SSO MFA methods are supported on the web terminal, please register a supported mfa method to connect to this server")
}

var codec protobufMFACodec
if err := stream.WriteChallenge(&challenge, codec); err != nil {
return nil, trace.Wrap(err)
}
var codec protobufMFACodec
if err := stream.WriteChallenge(&challenge, codec); err != nil {
return nil, trace.Wrap(err)
}

resp, err := stream.ReadChallengeResponse(codec)
return resp, trace.Wrap(err)
})
resp, err := stream.ReadChallengeResponse(codec)
return resp, trace.Wrap(err)
})
},
}
}

type connectWithMFAFn = func(ctx context.Context, ws terminal.WSConn, tc *client.TeleportClient, accessChecker services.AccessChecker, getAgent teleagent.Getter, signer agentless.SignerCreator) (*client.NodeClient, error)
Expand Down Expand Up @@ -822,7 +818,8 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
if t.participantMode == types.SessionModeratorMode {
beforeStart = func(out io.Writer) {
nc.OnMFA = func() {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream, "" /*ssoChannelID*/)); err != nil {
baseCeremony := newMFACeremony(t.stream.WSStream, nil)
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), baseCeremony); err != nil {
t.log.WithError(err).Warn("Unable to stream terminal - failure performing presence checks")
return
}
Expand Down

0 comments on commit 94ce7c6

Please sign in to comment.