Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v16] Remove SolveAuthenticateChallenge from MFA Ceremony in favor of PromptConstructor #47769

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 6 additions & 13 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ import (
// Ceremony is an MFA ceremony.
type Ceremony struct {
// CreateAuthenticateChallenge creates an authentication challenge.
CreateAuthenticateChallenge func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)
CreateAuthenticateChallenge CreateAuthenticateChallengeFunc
// PromptConstructor creates a prompt to prompt the user to solve an authentication challenge.
PromptConstructor PromptConstructor
// SolveAuthenticateChallenge solves an authentication challenge. Used in non-interactive settings,
// such as the WebUI with layers abstracting user interaction, and tests.
SolveAuthenticateChallenge func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
}

// CreateAuthenticateChallengeFunc is a function that creates an authentication challenge.
type CreateAuthenticateChallengeFunc func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)

// Run the MFA ceremony.
//
// req may be nil if ceremony.CreateAuthenticateChallenge does not require it, e.g. in
Expand All @@ -44,8 +44,6 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
switch {
case c.CreateAuthenticateChallenge == nil:
return nil, trace.BadParameter("mfa ceremony must have CreateAuthenticateChallenge set in order to begin")
case c.SolveAuthenticateChallenge != nil && c.PromptConstructor != nil:
return nil, trace.BadParameter("mfa ceremony should have SolveAuthenticateChallenge or PromptConstructor set, not both")
case req == nil:
// req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources
// its own req or uses a different rpc, e.g. moderated sessions.
Expand All @@ -72,13 +70,8 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
return nil, &ErrMFANotRequired
}

if c.SolveAuthenticateChallenge == nil && c.PromptConstructor == nil {
return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have SolveAuthenticateChallenge or PromptConstructor set in order to succeed")
}

if c.SolveAuthenticateChallenge != nil {
resp, err := c.SolveAuthenticateChallenge(ctx, chal)
return resp, trace.Wrap(err)
if c.PromptConstructor == nil {
return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have PromptConstructor set in order to succeed")
}

resp, err := c.PromptConstructor(promptOpts...).Run(ctx, chal)
Expand Down
40 changes: 8 additions & 32 deletions api/mfa/ceremony_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,6 @@ func TestPerformMFACeremony(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony success solve",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony not required",
ceremony: &mfa.Ceremony{
Expand All @@ -87,8 +73,10 @@ func TestPerformMFACeremony(t *testing.T) {
MFARequired: proto.MFARequired_MFA_REQUIRED_NO,
}, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected mfa not required")
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected mfa not required")
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
Expand All @@ -101,8 +89,10 @@ func TestPerformMFACeremony(t *testing.T) {
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return nil, errors.New("create authenticate challenge failure")
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected challenge failure")
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected challenge failure")
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
Expand All @@ -125,20 +115,6 @@ func TestPerformMFACeremony(t *testing.T) {
assert.ErrorContains(t, err, "prompt mfa failure")
assert.Nil(t, mr)
},
}, {
name: "NOK solve mfa fail",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, errors.New("solve mfa failure")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "solve mfa failure")
assert.Nil(t, mr)
},
},
} {
t.Run(tt.name, func(t *testing.T) {
Expand Down
12 changes: 7 additions & 5 deletions lib/auth/helpers_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ func NewTestDeviceFromChallenge(c *proto.MFARegisterChallenge, opts ...TestDevic
// RegisterTestDevice creates and registers a TestDevice.
// TOTP devices require a clock option.
func RegisterTestDevice(
ctx context.Context, clt authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice, opts ...TestDeviceOpt) (*TestDevice, error) {
ctx context.Context, clt authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice, opts ...TestDeviceOpt,
) (*TestDevice, error) {
dev := &TestDevice{} // Remaining parameters set during registration
for _, opt := range opts {
opt(dev)
Expand All @@ -104,10 +105,12 @@ type authClientI interface {

func (d *TestDevice) registerDevice(ctx context.Context, authClient authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice) error {
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(_ context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return authenticator.SolveAuthn(chal)
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return authenticator.SolveAuthn(chal)
})
},
CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge,
}

authnSolved, err := mfaCeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
Expand Down Expand Up @@ -212,7 +215,6 @@ func (d *TestDevice) solveRegister(c *proto.MFARegisterChallenge) (*proto.MFAReg
default:
return nil, trace.BadParameter("unexpected challenge type: %T", c.Request)
}

}

func (d *TestDevice) solveRegisterWebauthn(c *proto.MFARegisterChallenge) (*proto.MFARegisterResponse, error) {
Expand Down
26 changes: 14 additions & 12 deletions lib/client/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,20 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai
}

mfaCeremony := &mfa.Ceremony{
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n")

mfaResp, err := mfaPrompt.Run(ctx, chal)
if err != nil {
fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err)
return nil, trace.Wrap(err)
}

fmt.Fprint(term, "\r\nTeleport > Received MFA presence confirmation\r\n")
return mfaResp, nil
})
},
CreateAuthenticateChallenge: func(ctx context.Context, _ *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
req := &proto.PresenceMFAChallengeSend{
Request: &proto.PresenceMFAChallengeSend_ChallengeRequest{
Expand All @@ -99,18 +113,6 @@ func RunPresenceTask(ctx context.Context, term io.Writer, maintainer PresenceMai

return challenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
fmt.Fprint(term, "\r\nTeleport > Please tap your MFA key\r\n")

mfaResp, err := mfaPrompt.Run(ctx, chal)
if err != nil {
fmt.Fprintf(term, "\r\nTeleport > Failed to confirm presence: %v\r\n", err)
return nil, trace.Wrap(err)
}

fmt.Fprint(term, "\r\nTeleport > Received MFA presence confirmation\r\n")
return mfaResp, nil
},
}

for {
Expand Down
96 changes: 49 additions & 47 deletions lib/web/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -377,65 +377,67 @@ func (h *Handler) performSessionMFACeremony(
}()

mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
codec := tdpMFACodec{}

// Send the challenge over the socket.
msg, err := codec.Encode(
&client.MFAAuthenticateChallenge{
WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge),
},
defaults.WebsocketWebauthnChallenge,
)
if err != nil {
return nil, trace.Wrap(err)
}

if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil {
return nil, trace.Wrap(err)
}

span.AddEvent("waiting for user to complete mfa ceremony")
var buf []byte
// Loop through incoming messages until we receive an MFA message that lets us
// complete the ceremony. Non-MFA messages (e.g. ClientScreenSpecs representing
// screen resizes) are withheld for later.
for {
var ty int
ty, buf, err = ws.ReadMessage()
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
codec := tdpMFACodec{}

// Send the challenge over the socket.
msg, err := codec.Encode(
&client.MFAAuthenticateChallenge{
WebauthnChallenge: wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge),
},
defaults.WebsocketWebauthnChallenge,
)
if err != nil {
return nil, trace.Wrap(err)
}
if ty != websocket.BinaryMessage {
return nil, trace.BadParameter("received unexpected web socket message type %d", ty)
}
if len(buf) == 0 {
return nil, trace.BadParameter("empty message received")

if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil {
return nil, trace.Wrap(err)
}

if tdp.MessageType(buf[0]) != tdp.TypeMFA {
// This is not an MFA message, withhold it for later.
msg, err := tdp.Decode(buf)
h.log.Debugf("Received non-MFA message, withholding:", msg)
span.AddEvent("waiting for user to complete mfa ceremony")
var buf []byte
// Loop through incoming messages until we receive an MFA message that lets us
// complete the ceremony. Non-MFA messages (e.g. ClientScreenSpecs representing
// screen resizes) are withheld for later.
for {
var ty int
ty, buf, err = ws.ReadMessage()
if err != nil {
return nil, trace.Wrap(err)
}
*withheld = append(*withheld, msg)
continue
}
if ty != websocket.BinaryMessage {
return nil, trace.BadParameter("received unexpected web socket message type %d", ty)
}
if len(buf) == 0 {
return nil, trace.BadParameter("empty message received")
}

break
}
if tdp.MessageType(buf[0]) != tdp.TypeMFA {
// This is not an MFA message, withhold it for later.
msg, err := tdp.Decode(buf)
h.log.Debugf("Received non-MFA message, withholding:", msg)
if err != nil {
return nil, trace.Wrap(err)
}
*withheld = append(*withheld, msg)
continue
}

assertion, err := codec.DecodeResponse(buf, defaults.WebsocketWebauthnChallenge)
if err != nil {
return nil, trace.Wrap(err)
}
span.AddEvent("mfa ceremony completed")
break
}

return assertion, nil
assertion, err := codec.DecodeResponse(buf, defaults.WebsocketWebauthnChallenge)
if err != nil {
return nil, trace.Wrap(err)
}
span.AddEvent("mfa ceremony completed")

return assertion, nil
})
},
CreateAuthenticateChallenge: sctx.cfg.RootClient.CreateAuthenticateChallenge,
}

_, newCerts, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{
Expand Down
11 changes: 1 addition & 10 deletions lib/web/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ import (

clientproto "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth/authclient"
Expand Down Expand Up @@ -236,18 +235,10 @@ func (p *podHandler) handler(r *http.Request) error {
Usage: clientproto.UserCertsRequest_Kubernetes,
}

mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: p.sctx.cfg.RootClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(ctx context.Context, chal *clientproto.MFAAuthenticateChallenge) (*clientproto.MFAAuthenticateResponse, error) {
assertion, err := mfaPrompt(stream.WSStream, protobufMFACodec{}).Run(ctx, chal)
return assertion, trace.Wrap(err)
},
}

_, certs, err := client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{
CurrentAuthClient: p.userClient,
RootAuthClient: p.sctx.cfg.RootClient,
MFACeremony: mfaCeremony,
MFACeremony: newMFACeremony(stream.WSStream, p.sctx.cfg.RootClient.CreateAuthenticateChallenge),
MFAAgainstRoot: p.sctx.cfg.RootClusterName == p.teleportCluster,
MFARequiredReq: &clientproto.IsMFARequiredRequest{
Target: &clientproto.IsMFARequiredRequest_KubernetesCluster{KubernetesCluster: p.req.KubeCluster},
Expand Down
26 changes: 13 additions & 13 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,20 +591,10 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te
SSHLogin: tc.HostLogin,
}

mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: t.ctx.cfg.RootClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
span.AddEvent("prompting user with mfa challenge")
assertion, err := mfaPrompt(wsStream, protobufMFACodec{}).Run(ctx, chal)
span.AddEvent("user completed mfa challenge")
return assertion, trace.Wrap(err)
},
}

key, _, err = client.PerformSessionMFACeremony(ctx, client.PerformSessionMFACeremonyParams{
CurrentAuthClient: t.userAuthClient,
RootAuthClient: t.ctx.cfg.RootClient,
MFACeremony: mfaCeremony,
MFACeremony: newMFACeremony(wsStream, t.ctx.cfg.RootClient.CreateAuthenticateChallenge),
MFAAgainstRoot: t.ctx.cfg.RootClusterName == tc.SiteName,
MFARequiredReq: mfaRequiredReq,
CertsReq: certsReq,
Expand All @@ -623,9 +613,19 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te
return []ssh.AuthMethod{am}, nil
}

func mfaPrompt(stream *terminal.WSStream, codec terminal.MFACodec) mfa.Prompt {
func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc) *mfa.Ceremony {
return &mfa.Ceremony{
CreateAuthenticateChallenge: createAuthenticateChallenge,
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return newMFAPrompt(stream)
},
}
}

func newMFAPrompt(stream *terminal.WSStream) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
var challenge *client.MFAAuthenticateChallenge
var codec protobufMFACodec

// Convert from proto to JSON types.
switch {
Expand Down Expand Up @@ -801,7 +801,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
if t.participantMode == types.SessionModeratorMode {
beforeStart = func(out io.Writer) {
nc.OnMFA = func() {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), mfaPrompt(t.stream.WSStream, protobufMFACodec{})); err != nil {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream)); err != nil {
t.log.WithError(err).Warn("Unable to stream terminal - failure performing presence checks")
return
}
Expand Down
Loading