From bb9d9914a43edc8a245ef04ce57732ea434b6e38 Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Mon, 21 Oct 2024 11:14:46 -0700 Subject: [PATCH] Remove `SolveAuthenticateChallenge ` from MFA Ceremony in favor of `PromptConstructor` (#47650) * Remove mfa.Ceremony.SolveChallenge in favor of PromptFunc. * Update api/mfa/ceremony.go Co-authored-by: Marco Dinis --------- Co-authored-by: Marco Dinis --- api/mfa/ceremony.go | 19 +++----- api/mfa/ceremony_test.go | 40 ++++------------ lib/auth/helpers_mfa.go | 12 +++-- lib/client/presence.go | 26 ++++++----- lib/web/desktop.go | 99 +++++++++++++++++++++------------------- lib/web/kube.go | 11 +---- lib/web/terminal.go | 26 +++++------ 7 files changed, 100 insertions(+), 133 deletions(-) diff --git a/api/mfa/ceremony.go b/api/mfa/ceremony.go index f3c5f88e23d65..67b55e8fea379 100644 --- a/api/mfa/ceremony.go +++ b/api/mfa/ceremony.go @@ -28,14 +28,14 @@ import ( // Ceremony is an MFA ceremony. type Ceremony struct { // CreateAuthenticateChallenge creates an authentication challenge. - CreateAuthenticateChallenge func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) + CreateAuthenticateChallenge CreateAuthenticateChallengeFunc // PromptConstructor creates a prompt to prompt the user to solve an authentication challenge. PromptConstructor PromptConstructor - // SolveAuthenticateChallenge solves an authentication challenge. Used in non-interactive settings, - // such as the WebUI with layers abstracting user interaction, and tests. - SolveAuthenticateChallenge func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) } +// CreateAuthenticateChallengeFunc is a function that creates an authentication challenge. +type CreateAuthenticateChallengeFunc func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) + // Run the MFA ceremony. // // req may be nil if ceremony.CreateAuthenticateChallenge does not require it, e.g. in @@ -44,8 +44,6 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen switch { case c.CreateAuthenticateChallenge == nil: return nil, trace.BadParameter("mfa ceremony must have CreateAuthenticateChallenge set in order to begin") - case c.SolveAuthenticateChallenge != nil && c.PromptConstructor != nil: - return nil, trace.BadParameter("mfa ceremony should have SolveAuthenticateChallenge or PromptConstructor set, not both") case req == nil: // req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources // its own req or uses a different rpc, e.g. moderated sessions. @@ -72,13 +70,8 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen return nil, &ErrMFANotRequired } - if c.SolveAuthenticateChallenge == nil && c.PromptConstructor == nil { - return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have SolveAuthenticateChallenge or PromptConstructor set in order to succeed") - } - - if c.SolveAuthenticateChallenge != nil { - resp, err := c.SolveAuthenticateChallenge(ctx, chal) - return resp, trace.Wrap(err) + if c.PromptConstructor == nil { + return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have PromptConstructor set in order to succeed") } resp, err := c.PromptConstructor(promptOpts...).Run(ctx, chal) diff --git a/api/mfa/ceremony_test.go b/api/mfa/ceremony_test.go index bb6a24b6fcdbe..7d94fd4de5327 100644 --- a/api/mfa/ceremony_test.go +++ b/api/mfa/ceremony_test.go @@ -65,20 +65,6 @@ func TestPerformMFACeremony(t *testing.T) { assert.NoError(t, err) assert.Equal(t, testMFAResponse, mr) }, - }, { - name: "OK ceremony success solve", - ceremony: &mfa.Ceremony{ - CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) { - return testMFAChallenge, nil - }, - SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return testMFAResponse, nil - }, - }, - assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) { - assert.NoError(t, err) - assert.Equal(t, testMFAResponse, mr) - }, }, { name: "OK ceremony not required", ceremony: &mfa.Ceremony{ @@ -87,8 +73,10 @@ func TestPerformMFACeremony(t *testing.T) { MFARequired: proto.MFARequired_MFA_REQUIRED_NO, }, nil }, - SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return nil, trace.BadParameter("expected mfa not required") + PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt { + return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + return nil, trace.BadParameter("expected mfa not required") + }) }, }, assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) { @@ -101,8 +89,10 @@ func TestPerformMFACeremony(t *testing.T) { CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) { return nil, errors.New("create authenticate challenge failure") }, - SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return nil, trace.BadParameter("expected challenge failure") + PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt { + return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + return nil, trace.BadParameter("expected challenge failure") + }) }, }, assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) { @@ -125,20 +115,6 @@ func TestPerformMFACeremony(t *testing.T) { assert.ErrorContains(t, err, "prompt mfa failure") assert.Nil(t, mr) }, - }, { - name: "NOK solve mfa fail", - ceremony: &mfa.Ceremony{ - CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) { - return testMFAChallenge, nil - }, - SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return nil, errors.New("solve mfa failure") - }, - }, - assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) { - assert.ErrorContains(t, err, "solve mfa failure") - assert.Nil(t, mr) - }, }, } { t.Run(tt.name, func(t *testing.T) { diff --git a/lib/auth/helpers_mfa.go b/lib/auth/helpers_mfa.go index d41e5e6e95ac4..eca034791db4c 100644 --- a/lib/auth/helpers_mfa.go +++ b/lib/auth/helpers_mfa.go @@ -78,7 +78,8 @@ func NewTestDeviceFromChallenge(c *proto.MFARegisterChallenge, opts ...TestDevic // RegisterTestDevice creates and registers a TestDevice. // TOTP devices require a clock option. func RegisterTestDevice( - ctx context.Context, clt authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice, opts ...TestDeviceOpt) (*TestDevice, error) { + ctx context.Context, clt authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice, opts ...TestDeviceOpt, +) (*TestDevice, error) { dev := &TestDevice{} // Remaining parameters set during registration for _, opt := range opts { opt(dev) @@ -104,10 +105,12 @@ type authClientI interface { func (d *TestDevice) registerDevice(ctx context.Context, authClient authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice) error { mfaCeremony := &mfa.Ceremony{ - CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge, - SolveAuthenticateChallenge: func(_ context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - return authenticator.SolveAuthn(chal) + PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt { + return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + return authenticator.SolveAuthn(chal) + }) }, + CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge, } authnSolved, err := mfaCeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{ @@ -212,7 +215,6 @@ func (d *TestDevice) solveRegister(c *proto.MFARegisterChallenge) (*proto.MFAReg default: return nil, trace.BadParameter("unexpected challenge type: %T", c.Request) } - } func (d *TestDevice) solveRegisterWebauthn(c *proto.MFARegisterChallenge) (*proto.MFARegisterResponse, error) { diff --git a/lib/client/presence.go b/lib/client/presence.go index 14a8cb5d542f6..8ad63cc997bc6 100644 --- a/lib/client/presence.go +++ b/lib/client/presence.go @@ -77,6 +77,20 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai } mfaCeremony := &mfa.Ceremony{ + PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt { + return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n") + + mfaResp, err := mfaPrompt.Run(ctx, chal) + if err != nil { + fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err) + return nil, trace.Wrap(err) + } + + fmt.Fprint(term, "\r\nTeleport > Received MFA presence confirmation\r\n") + return mfaResp, nil + }) + }, CreateAuthenticateChallenge: func(ctx context.Context, _ *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) { req := &proto.PresenceMFAChallengeSend{ Request: &proto.PresenceMFAChallengeSend_ChallengeRequest{ @@ -99,18 +113,6 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai return challenge, nil }, - SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n") - - mfaResp, err := mfaPrompt.Run(ctx, chal) - if err != nil { - fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err) - return nil, trace.Wrap(err) - } - - fmt.Fprint(term, "\r\nTeleport > Received MFA presence confirmation\r\n") - return mfaResp, nil - }, } for { diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 5c27ce096481e..9173a7bf57fc0 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -370,65 +370,67 @@ func (h *Handler) performSessionMFACeremony( }() mfaCeremony := &mfa.Ceremony{ - CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge, - SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { - codec := tdpMFACodec{} - - // Send the challenge over the socket. - msg, err := codec.Encode( - &client.MFAAuthenticateChallenge{ - WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge), - }, - defaults.WebsocketWebauthnChallenge, - ) - if err != nil { - return nil, trace.Wrap(err) - } - - if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil { - return nil, trace.Wrap(err) - } - - span.AddEvent("waiting for user to complete mfa ceremony") - var buf []byte - // Loop through incoming messages until we receive an MFA message that lets us - // complete the ceremony. Non-MFA messages (e.g. ClientScreenSpecs representing - // screen resizes) are withheld for later. - for { - var ty int - ty, buf, err = ws.ReadMessage() + PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt { + return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) { + codec := tdpMFACodec{} + + // Send the challenge over the socket. + msg, err := codec.Encode( + &client.MFAAuthenticateChallenge{ + WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge), + }, + defaults.WebsocketWebauthnChallenge, + ) if err != nil { return nil, trace.Wrap(err) } - if ty != websocket.BinaryMessage { - return nil, trace.BadParameter("received unexpected web socket message type %d", ty) - } - if len(buf) == 0 { - return nil, trace.BadParameter("empty message received") + + if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil { + return nil, trace.Wrap(err) } - if tdp.MessageType(buf[0]) != tdp.TypeMFA { - // This is not an MFA message, withhold it for later. - msg, err := tdp.Decode(buf) - h.log.Debugf("Received non-MFA message, withholding:", msg) + span.AddEvent("waiting for user to complete mfa ceremony") + var buf []byte + // Loop through incoming messages until we receive an MFA message that lets us + // complete the ceremony. Non-MFA messages (e.g. ClientScreenSpecs representing + // screen resizes) are withheld for later. + for { + var ty int + ty, buf, err = ws.ReadMessage() if err != nil { return nil, trace.Wrap(err) } - *withheld = append(*withheld, msg) - continue - } + if ty != websocket.BinaryMessage { + return nil, trace.BadParameter("received unexpected web socket message type %d", ty) + } + if len(buf) == 0 { + return nil, trace.BadParameter("empty message received") + } - break - } + if tdp.MessageType(buf[0]) != tdp.TypeMFA { + // This is not an MFA message, withhold it for later. + msg, err := tdp.Decode(buf) + h.log.Debugf("Received non-MFA message, withholding:", msg) + if err != nil { + return nil, trace.Wrap(err) + } + *withheld = append(*withheld, msg) + continue + } - assertion, err := codec.DecodeResponse(buf, defaults.WebsocketWebauthnChallenge) - if err != nil { - return nil, trace.Wrap(err) - } - span.AddEvent("mfa ceremony completed") + break + } - return assertion, nil + assertion, err := codec.DecodeResponse(buf, defaults.WebsocketWebauthnChallenge) + if err != nil { + return nil, trace.Wrap(err) + } + span.AddEvent("mfa ceremony completed") + + return assertion, nil + }) }, + CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge, } _, newCerts, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{ @@ -475,7 +477,8 @@ type connector struct { // to any of the services or if it encounters an error that is not a connection problem. func (c *connector) connectToWindowsService( clusterName string, - desktopServiceIDs []string) (conn net.Conn, version string, err error) { + desktopServiceIDs []string, +) (conn net.Conn, version string, err error) { for _, id := range desktopServiceIDs { conn, ver, err := c.tryConnect(clusterName, id) if err != nil && !trace.IsConnectionProblem(err) { diff --git a/lib/web/kube.go b/lib/web/kube.go index 3c8a7c32d7602..aad3a0a25c817 100644 --- a/lib/web/kube.go +++ b/lib/web/kube.go @@ -41,7 +41,6 @@ import ( clientproto "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" - "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth/authclient" @@ -240,18 +239,10 @@ func (p *podHandler) handler(r *http.Request) error { Usage: clientproto.UserCertsRequest_Kubernetes, } - mfaCeremony := &mfa.Ceremony{ - CreateAuthenticateChallenge: p.sctx.cfg.RootClient.CreateAuthenticateChallenge, - SolveAuthenticateChallenge: func(ctx context.Context, chal *clientproto.MFAAuthenticateChallenge) (*clientproto.MFAAuthenticateResponse, error) { - assertion, err := mfaPrompt(stream.WSStream, protobufMFACodec{}).Run(ctx, chal) - return assertion, trace.Wrap(err) - }, - } - _, certs, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{ CurrentAuthClient: p.userClient, RootAuthClient: p.sctx.cfg.RootClient, - MFACeremony: mfaCeremony, + MFACeremony: newMFACeremony(stream.WSStream, p.sctx.cfg.RootClient.CreateAuthenticateChallenge), MFAAgainstRoot: p.sctx.cfg.RootClusterName == p.teleportCluster, MFARequiredReq: &clientproto.IsMFARequiredRequest{ Target: &clientproto.IsMFARequiredRequest_KubernetesCluster{KubernetesCluster: p.req.KubeCluster}, diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 2d8ca68a6ec0f..4727499ce281c 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -586,20 +586,10 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te SSHLogin: tc.HostLogin, } - mfaCeremony := &mfa.Ceremony{ - CreateAuthenticateChallenge: t.ctx.cfg.RootClient.CreateAuthenticateChallenge, - SolveAuthenticateChallenge: func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { - span.AddEvent("prompting user with mfa challenge") - assertion, err := mfaPrompt(wsStream, protobufMFACodec{}).Run(ctx, chal) - span.AddEvent("user completed mfa challenge") - return assertion, trace.Wrap(err) - }, - } - _, certs, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{ CurrentAuthClient: t.userAuthClient, RootAuthClient: t.ctx.cfg.RootClient, - MFACeremony: mfaCeremony, + MFACeremony: newMFACeremony(wsStream, t.ctx.cfg.RootClient.CreateAuthenticateChallenge), MFAAgainstRoot: t.ctx.cfg.RootClusterName == tc.SiteName, MFARequiredReq: mfaRequiredReq, CertsReq: certsReq, @@ -619,9 +609,19 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te return []ssh.AuthMethod{am}, nil } -func mfaPrompt(stream *terminal.WSStream, codec terminal.MFACodec) mfa.Prompt { +func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc) *mfa.Ceremony { + return &mfa.Ceremony{ + CreateAuthenticateChallenge: createAuthenticateChallenge, + PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt { + return newMFAPrompt(stream) + }, + } +} + +func newMFAPrompt(stream *terminal.WSStream) mfa.Prompt { return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { var challenge *client.MFAAuthenticateChallenge + var codec protobufMFACodec // Convert from proto to JSON types. switch { @@ -797,7 +797,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor if t.participantMode == types.SessionModeratorMode { beforeStart = func(out io.Writer) { nc.OnMFA = func() { - if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), mfaPrompt(t.stream.WSStream, protobufMFACodec{})); err != nil { + if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream)); err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failure performing presence checks") return }