Skip to content

Commit

Permalink
Remove SolveAuthenticateChallenge from MFA Ceremony in favor of `P…
Browse files Browse the repository at this point in the history
…romptConstructor` (#47650)

* Remove mfa.Ceremony.SolveChallenge in favor of PromptFunc.

* Update api/mfa/ceremony.go

Co-authored-by: Marco Dinis <[email protected]>

---------

Co-authored-by: Marco Dinis <[email protected]>
  • Loading branch information
Joerger and marcoandredinis authored Oct 21, 2024
1 parent 23a245e commit bb9d991
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 133 deletions.
19 changes: 6 additions & 13 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ import (
// Ceremony is an MFA ceremony.
type Ceremony struct {
// CreateAuthenticateChallenge creates an authentication challenge.
CreateAuthenticateChallenge func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)
CreateAuthenticateChallenge CreateAuthenticateChallengeFunc
// PromptConstructor creates a prompt to prompt the user to solve an authentication challenge.
PromptConstructor PromptConstructor
// SolveAuthenticateChallenge solves an authentication challenge. Used in non-interactive settings,
// such as the WebUI with layers abstracting user interaction, and tests.
SolveAuthenticateChallenge func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
}

// CreateAuthenticateChallengeFunc is a function that creates an authentication challenge.
type CreateAuthenticateChallengeFunc func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)

// Run the MFA ceremony.
//
// req may be nil if ceremony.CreateAuthenticateChallenge does not require it, e.g. in
Expand All @@ -44,8 +44,6 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
switch {
case c.CreateAuthenticateChallenge == nil:
return nil, trace.BadParameter("mfa ceremony must have CreateAuthenticateChallenge set in order to begin")
case c.SolveAuthenticateChallenge != nil && c.PromptConstructor != nil:
return nil, trace.BadParameter("mfa ceremony should have SolveAuthenticateChallenge or PromptConstructor set, not both")
case req == nil:
// req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources
// its own req or uses a different rpc, e.g. moderated sessions.
Expand All @@ -72,13 +70,8 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
return nil, &ErrMFANotRequired
}

if c.SolveAuthenticateChallenge == nil && c.PromptConstructor == nil {
return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have SolveAuthenticateChallenge or PromptConstructor set in order to succeed")
}

if c.SolveAuthenticateChallenge != nil {
resp, err := c.SolveAuthenticateChallenge(ctx, chal)
return resp, trace.Wrap(err)
if c.PromptConstructor == nil {
return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have PromptConstructor set in order to succeed")
}

resp, err := c.PromptConstructor(promptOpts...).Run(ctx, chal)
Expand Down
40 changes: 8 additions & 32 deletions api/mfa/ceremony_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,6 @@ func TestPerformMFACeremony(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony success solve",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony not required",
ceremony: &mfa.Ceremony{
Expand All @@ -87,8 +73,10 @@ func TestPerformMFACeremony(t *testing.T) {
MFARequired: proto.MFARequired_MFA_REQUIRED_NO,
}, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected mfa not required")
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected mfa not required")
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
Expand All @@ -101,8 +89,10 @@ func TestPerformMFACeremony(t *testing.T) {
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return nil, errors.New("create authenticate challenge failure")
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected challenge failure")
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected challenge failure")
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
Expand All @@ -125,20 +115,6 @@ func TestPerformMFACeremony(t *testing.T) {
assert.ErrorContains(t, err, "prompt mfa failure")
assert.Nil(t, mr)
},
}, {
name: "NOK solve mfa fail",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, errors.New("solve mfa failure")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "solve mfa failure")
assert.Nil(t, mr)
},
},
} {
t.Run(tt.name, func(t *testing.T) {
Expand Down
12 changes: 7 additions & 5 deletions lib/auth/helpers_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ func NewTestDeviceFromChallenge(c *proto.MFARegisterChallenge, opts ...TestDevic
// RegisterTestDevice creates and registers a TestDevice.
// TOTP devices require a clock option.
func RegisterTestDevice(
ctx context.Context, clt authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice, opts ...TestDeviceOpt) (*TestDevice, error) {
ctx context.Context, clt authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice, opts ...TestDeviceOpt,
) (*TestDevice, error) {
dev := &TestDevice{} // Remaining parameters set during registration
for _, opt := range opts {
opt(dev)
Expand All @@ -104,10 +105,12 @@ type authClientI interface {

func (d *TestDevice) registerDevice(ctx context.Context, authClient authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice) error {
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(_ context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return authenticator.SolveAuthn(chal)
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return authenticator.SolveAuthn(chal)
})
},
CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge,
}

authnSolved, err := mfaCeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
Expand Down Expand Up @@ -212,7 +215,6 @@ func (d *TestDevice) solveRegister(c *proto.MFARegisterChallenge) (*proto.MFAReg
default:
return nil, trace.BadParameter("unexpected challenge type: %T", c.Request)
}

}

func (d *TestDevice) solveRegisterWebauthn(c *proto.MFARegisterChallenge) (*proto.MFARegisterResponse, error) {
Expand Down
26 changes: 14 additions & 12 deletions lib/client/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,20 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
}

mfaCeremony := &mfa.Ceremony{
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n")

mfaResp, err := mfaPrompt.Run(ctx, chal)
if err != nil {
fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err)
return nil, trace.Wrap(err)
}

fmt.Fprint(term, "\r\nTeleport > Received MFA presence confirmation\r\n")
return mfaResp, nil
})
},
CreateAuthenticateChallenge: func(ctx context.Context, _ *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
req := &proto.PresenceMFAChallengeSend{
Request: &proto.PresenceMFAChallengeSend_ChallengeRequest{
Expand All @@ -99,18 +113,6 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai

return challenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n")

mfaResp, err := mfaPrompt.Run(ctx, chal)
if err != nil {
fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err)
return nil, trace.Wrap(err)
}

fmt.Fprint(term, "\r\nTeleport > Received MFA presence confirmation\r\n")
return mfaResp, nil
},
}

for {
Expand Down
99 changes: 51 additions & 48 deletions lib/web/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,65 +370,67 @@ func (h *Handler) performSessionMFACeremony(
}()

mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
codec := tdpMFACodec{}

// Send the challenge over the socket.
msg, err := codec.Encode(
&client.MFAAuthenticateChallenge{
WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge),
},
defaults.WebsocketWebauthnChallenge,
)
if err != nil {
return nil, trace.Wrap(err)
}

if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil {
return nil, trace.Wrap(err)
}

span.AddEvent("waiting for user to complete mfa ceremony")
var buf []byte
// Loop through incoming messages until we receive an MFA message that lets us
// complete the ceremony. Non-MFA messages (e.g. ClientScreenSpecs representing
// screen resizes) are withheld for later.
for {
var ty int
ty, buf, err = ws.ReadMessage()
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
codec := tdpMFACodec{}

// Send the challenge over the socket.
msg, err := codec.Encode(
&client.MFAAuthenticateChallenge{
WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge),
},
defaults.WebsocketWebauthnChallenge,
)
if err != nil {
return nil, trace.Wrap(err)
}
if ty != websocket.BinaryMessage {
return nil, trace.BadParameter("received unexpected web socket message type %d", ty)
}
if len(buf) == 0 {
return nil, trace.BadParameter("empty message received")

if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil {
return nil, trace.Wrap(err)
}

if tdp.MessageType(buf[0]) != tdp.TypeMFA {
// This is not an MFA message, withhold it for later.
msg, err := tdp.Decode(buf)
h.log.Debugf("Received non-MFA message, withholding:", msg)
span.AddEvent("waiting for user to complete mfa ceremony")
var buf []byte
// Loop through incoming messages until we receive an MFA message that lets us
// complete the ceremony. Non-MFA messages (e.g. ClientScreenSpecs representing
// screen resizes) are withheld for later.
for {
var ty int
ty, buf, err = ws.ReadMessage()
if err != nil {
return nil, trace.Wrap(err)
}
*withheld = append(*withheld, msg)
continue
}
if ty != websocket.BinaryMessage {
return nil, trace.BadParameter("received unexpected web socket message type %d", ty)
}
if len(buf) == 0 {
return nil, trace.BadParameter("empty message received")
}

break
}
if tdp.MessageType(buf[0]) != tdp.TypeMFA {
// This is not an MFA message, withhold it for later.
msg, err := tdp.Decode(buf)
h.log.Debugf("Received non-MFA message, withholding:", msg)
if err != nil {
return nil, trace.Wrap(err)
}
*withheld = append(*withheld, msg)
continue
}

assertion, err := codec.DecodeResponse(buf, defaults.WebsocketWebauthnChallenge)
if err != nil {
return nil, trace.Wrap(err)
}
span.AddEvent("mfa ceremony completed")
break
}

return assertion, nil
assertion, err := codec.DecodeResponse(buf, defaults.WebsocketWebauthnChallenge)
if err != nil {
return nil, trace.Wrap(err)
}
span.AddEvent("mfa ceremony completed")

return assertion, nil
})
},
CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge,
}

_, newCerts, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{
Expand Down Expand Up @@ -475,7 +477,8 @@ type connector struct {
// to any of the services or if it encounters an error that is not a connection problem.
func (c *connector) connectToWindowsService(
clusterName string,
desktopServiceIDs []string) (conn net.Conn, version string, err error) {
desktopServiceIDs []string,
) (conn net.Conn, version string, err error) {
for _, id := range desktopServiceIDs {
conn, ver, err := c.tryConnect(clusterName, id)
if err != nil && !trace.IsConnectionProblem(err) {
Expand Down
11 changes: 1 addition & 10 deletions lib/web/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import (

clientproto "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth/authclient"
Expand Down Expand Up @@ -240,18 +239,10 @@ func (p *podHandler) handler(r *http.Request) error {
Usage: clientproto.UserCertsRequest_Kubernetes,
}

mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: p.sctx.cfg.RootClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(ctx context.Context, chal *clientproto.MFAAuthenticateChallenge) (*clientproto.MFAAuthenticateResponse, error) {
assertion, err := mfaPrompt(stream.WSStream, protobufMFACodec{}).Run(ctx, chal)
return assertion, trace.Wrap(err)
},
}

_, certs, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{
CurrentAuthClient: p.userClient,
RootAuthClient: p.sctx.cfg.RootClient,
MFACeremony: mfaCeremony,
MFACeremony: newMFACeremony(stream.WSStream, p.sctx.cfg.RootClient.CreateAuthenticateChallenge),
MFAAgainstRoot: p.sctx.cfg.RootClusterName == p.teleportCluster,
MFARequiredReq: &clientproto.IsMFARequiredRequest{
Target: &clientproto.IsMFARequiredRequest_KubernetesCluster{KubernetesCluster: p.req.KubeCluster},
Expand Down
26 changes: 13 additions & 13 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -586,20 +586,10 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te
SSHLogin: tc.HostLogin,
}

mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: t.ctx.cfg.RootClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
span.AddEvent("prompting user with mfa challenge")
assertion, err := mfaPrompt(wsStream, protobufMFACodec{}).Run(ctx, chal)
span.AddEvent("user completed mfa challenge")
return assertion, trace.Wrap(err)
},
}

_, certs, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{
CurrentAuthClient: t.userAuthClient,
RootAuthClient: t.ctx.cfg.RootClient,
MFACeremony: mfaCeremony,
MFACeremony: newMFACeremony(wsStream, t.ctx.cfg.RootClient.CreateAuthenticateChallenge),
MFAAgainstRoot: t.ctx.cfg.RootClusterName == tc.SiteName,
MFARequiredReq: mfaRequiredReq,
CertsReq: certsReq,
Expand All @@ -619,9 +609,19 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te
return []ssh.AuthMethod{am}, nil
}

func mfaPrompt(stream *terminal.WSStream, codec terminal.MFACodec) mfa.Prompt {
func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc) *mfa.Ceremony {
return &mfa.Ceremony{
CreateAuthenticateChallenge: createAuthenticateChallenge,
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return newMFAPrompt(stream)
},
}
}

func newMFAPrompt(stream *terminal.WSStream) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
var challenge *client.MFAAuthenticateChallenge
var codec protobufMFACodec

// Convert from proto to JSON types.
switch {
Expand Down Expand Up @@ -797,7 +797,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
if t.participantMode == types.SessionModeratorMode {
beforeStart = func(out io.Writer) {
nc.OnMFA = func() {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), mfaPrompt(t.stream.WSStream, protobufMFACodec{})); err != nil {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream)); err != nil {
t.log.WithError(err).Warn("Unable to stream terminal - failure performing presence checks")
return
}
Expand Down

0 comments on commit bb9d991

Please sign in to comment.