Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify and simplify MFA Ceremony helpers #46986

Merged
merged 12 commits into from
Oct 2, 2024
20 changes: 4 additions & 16 deletions api/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package client
import (
"context"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/mfa"
)
Expand All @@ -29,19 +27,9 @@ import (
// and prompts the user to answer the challenge with the given promptOpts, and ultimately returning
// an MFA challenge response for the user.
func (c *Client) PerformMFACeremony(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
// Don't attempt the MFA ceremony if we can't prompt for a response.
if c.c.MFAPromptConstructor == nil {
return nil, trace.Wrap(&mfa.ErrMFANotSupported, "missing MFAPromptConstructor field, client cannot perform MFA ceremony")
}

return mfa.PerformMFACeremony(ctx, c, challengeRequest, promptOpts...)
}

// PromptMFA prompts the user for MFA. Implements [mfa.MFACeremonyClient].
func (c *Client) PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
if c.c.MFAPromptConstructor == nil {
return nil, trace.Wrap(&mfa.ErrMFANotSupported, "missing MFAPromptConstructor field, client cannot prompt for MFA")
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: c.CreateAuthenticateChallenge,
PromptConstructor: c.c.MFAPromptConstructor,
}

return c.c.MFAPromptConstructor(promptOpts...).Run(ctx, chal)
return mfaCeremony.Run(ctx, challengeRequest, promptOpts...)
}
66 changes: 41 additions & 25 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,37 @@ import (
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
)

// MFACeremonyClient is a client that can perform an MFA ceremony, from retrieving
// the MFA challenge to prompting for an MFA response from the user.
type MFACeremonyClient interface {
// CreateAuthenticateChallenge creates and returns MFA challenges for a users registered MFA devices.
CreateAuthenticateChallenge(ctx context.Context, in *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)
// PromptMFA prompts the user for MFA.
PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error)
// Ceremony is an MFA ceremony.
type Ceremony struct {
// CreateAuthenticateChallenge creates an authentication challenge.
CreateAuthenticateChallenge func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)
// PromptConstructor creates a prompt to prompt the user to solve an authentication challenge.
PromptConstructor PromptConstructor
// SolveAuthenticateChallenge solves an authentication challenge. Used in non-interactive settings,
// such as the WebUI with layers abstracting user interaction, and tests.
SolveAuthenticateChallenge func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
}

// PerformMFACeremony retrieves an MFA challenge from the server with the given challenge extensions
// and prompts the user to answer the challenge with the given promptOpts, and ultimately returning
// an MFA challenge response for the user.
func PerformMFACeremony(ctx context.Context, clt MFACeremonyClient, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error) {
if challengeRequest == nil {
return nil, trace.BadParameter("missing challenge request")
}

if challengeRequest.ChallengeExtensions == nil {
// Run the MFA ceremony.
//
// req may be nil if ceremony.CreateAuthenticateChallenge does not require it, e.g. in
// the moderated session mfa ceremony which uses a custom stream rpc to create challenges.
func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error) {
switch {
case c.CreateAuthenticateChallenge == nil:
return nil, trace.BadParameter("mfa ceremony must have CreateAuthenticateChallenge set in order to begin")
case c.SolveAuthenticateChallenge != nil && c.PromptConstructor != nil:
return nil, trace.BadParameter("mfa ceremony should have SolveAuthenticateChallenge or PromptConstructor set, not both")
case req == nil:
// req may be nil in cases where the ceremony's CreateAuthenticateChallenge sources
// its own req or uses a different rpc, e.g. moderated sessions.
case req.ChallengeExtensions == nil:
return nil, trace.BadParameter("missing challenge extensions")
}

if challengeRequest.ChallengeExtensions.Scope == mfav1.ChallengeScope_CHALLENGE_SCOPE_UNSPECIFIED {
case req.ChallengeExtensions.Scope == mfav1.ChallengeScope_CHALLENGE_SCOPE_UNSPECIFIED:
return nil, trace.BadParameter("mfa challenge scope must be specified")
}
Joerger marked this conversation as resolved.
Show resolved Hide resolved

chal, err := clt.CreateAuthenticateChallenge(ctx, challengeRequest)
chal, err := c.CreateAuthenticateChallenge(ctx, req)
if err != nil {
// CreateAuthenticateChallenge returns a bad parameter error when the client
// user is not a Teleport user - for example, the AdminRole. Treat this as an MFA
Expand All @@ -67,21 +72,31 @@ func PerformMFACeremony(ctx context.Context, clt MFACeremonyClient, challengeReq
return nil, &ErrMFANotRequired
}

return clt.PromptMFA(ctx, chal, promptOpts...)
if c.SolveAuthenticateChallenge == nil && c.PromptConstructor == nil {
return nil, trace.Wrap(&ErrMFANotSupported, "mfa ceremony must have SolveAuthenticateChallenge or PromptConstructor set in order to succeed")
}

if c.SolveAuthenticateChallenge != nil {
resp, err := c.SolveAuthenticateChallenge(ctx, chal)
return resp, trace.Wrap(err)
}

resp, err := c.PromptConstructor(promptOpts...).Run(ctx, chal)
return resp, trace.Wrap(err)
}

type MFACeremony func(ctx context.Context, challengeRequest *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error)
// CeremonyFn is a function that will carry out an MFA ceremony.
type CeremonyFn func(ctx context.Context, in *proto.CreateAuthenticateChallengeRequest, promptOpts ...PromptOpt) (*proto.MFAAuthenticateResponse, error)

// PerformAdminActionMFACeremony retrieves an MFA challenge from the server for an admin
// action, prompts the user to answer the challenge, and returns the resulting MFA response.
func PerformAdminActionMFACeremony(ctx context.Context, mfaCeremony MFACeremony, allowReuse bool) (*proto.MFAAuthenticateResponse, error) {
func PerformAdminActionMFACeremony(ctx context.Context, mfaCeremony CeremonyFn, allowReuse bool) (*proto.MFAAuthenticateResponse, error) {
allowReuseExt := mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_NO
if allowReuse {
allowReuseExt = mfav1.ChallengeAllowReuse_CHALLENGE_ALLOW_REUSE_YES
}

challengeRequest := &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{},
MFARequiredCheck: &proto.IsMFARequiredRequest{
Target: &proto.IsMFARequiredRequest_AdminAction{
AdminAction: &proto.AdminAction{},
Expand All @@ -93,5 +108,6 @@ func PerformAdminActionMFACeremony(ctx context.Context, mfaCeremony MFACeremony,
},
}

return mfaCeremony(ctx, challengeRequest, WithPromptReasonAdminAction())
resp, err := mfaCeremony(ctx, challengeRequest, WithPromptReasonAdminAction())
return resp, trace.Wrap(err)
}
114 changes: 69 additions & 45 deletions api/mfa/ceremony_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"

"github.com/gravitational/teleport/api/client/proto"
Expand All @@ -32,6 +33,9 @@ func TestPerformMFACeremony(t *testing.T) {
t.Parallel()
ctx := context.Background()

testMFAChallenge := &proto.MFAAuthenticateChallenge{
TOTP: &proto.TOTPChallenge{},
}
testMFAResponse := &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_TOTP{
TOTP: &proto.TOTPResponse{
Expand All @@ -42,52 +46,103 @@ func TestPerformMFACeremony(t *testing.T) {

for _, tt := range []struct {
name string
ceremonyClient *fakeMFACeremonyClient
ceremony *mfa.Ceremony
assertCeremonyResponse func(*testing.T, *proto.MFAAuthenticateResponse, error, ...interface{})
}{
{
name: "OK ceremony success",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
name: "OK ceremony success prompt",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony success solve",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.NoError(t, err)
assert.Equal(t, testMFAResponse, mr)
},
}, {
name: "OK ceremony not required",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
mfaRequired: proto.MFARequired_MFA_REQUIRED_NO,
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return &proto.MFAAuthenticateChallenge{
MFARequired: proto.MFARequired_MFA_REQUIRED_NO,
}, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected mfa not required")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.Error(t, err, mfa.ErrMFANotRequired)
assert.Nil(t, mr)
},
}, {
name: "NOK create challenge fail",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
createAuthenticateChallengeErr: errors.New("create authenticate challenge failure"),
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return nil, errors.New("create authenticate challenge failure")
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, trace.BadParameter("expected challenge failure")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "create authenticate challenge failure")
assert.Nil(t, mr)
},
}, {
name: "NOK prompt mfa fail",
ceremonyClient: &fakeMFACeremonyClient{
challengeResponse: testMFAResponse,
promptMFAErr: errors.New("prompt mfa failure"),
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
PromptConstructor: func(po ...mfa.PromptOpt) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, errors.New("prompt mfa failure")
})
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "prompt mfa failure")
assert.Nil(t, mr)
},
}, {
name: "NOK solve mfa fail",
ceremony: &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
SolveAuthenticateChallenge: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return nil, errors.New("solve mfa failure")
},
},
assertCeremonyResponse: func(t *testing.T, mr *proto.MFAAuthenticateResponse, err error, i ...interface{}) {
assert.ErrorContains(t, err, "solve mfa failure")
assert.Nil(t, mr)
},
},
} {
t.Run(tt.name, func(t *testing.T) {
resp, err := mfa.PerformMFACeremony(ctx, tt.ceremonyClient, &proto.CreateAuthenticateChallengeRequest{
resp, err := tt.ceremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_ADMIN_ACTION,
},
Expand All @@ -97,34 +152,3 @@ func TestPerformMFACeremony(t *testing.T) {
})
}
}

type fakeMFACeremonyClient struct {
createAuthenticateChallengeErr error
promptMFAErr error
mfaRequired proto.MFARequired
challengeResponse *proto.MFAAuthenticateResponse
}

func (c *fakeMFACeremonyClient) CreateAuthenticateChallenge(ctx context.Context, in *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
if c.createAuthenticateChallengeErr != nil {
return nil, c.createAuthenticateChallengeErr
}

chal := &proto.MFAAuthenticateChallenge{
TOTP: &proto.TOTPChallenge{},
}

if in.MFARequiredCheck != nil {
chal.MFARequired = c.mfaRequired
}

return chal, nil
}

func (c *fakeMFACeremonyClient) PromptMFA(ctx context.Context, chal *proto.MFAAuthenticateChallenge, promptOpts ...mfa.PromptOpt) (*proto.MFAAuthenticateResponse, error) {
if c.promptMFAErr != nil {
return nil, c.promptMFAErr
}

return c.challengeResponse, nil
}
2 changes: 1 addition & 1 deletion api/utils/grpc/interceptors/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
// to the rpc call when an MFA response is provided through the context. Additionally,
// when the call returns an error that indicates that MFA is required, this interceptor
// will prompt for MFA using the given mfaCeremony and retry.
func WithMFAUnaryInterceptor(mfaCeremony mfa.MFACeremony) grpc.UnaryClientInterceptor {
func WithMFAUnaryInterceptor(mfaCeremony mfa.CeremonyFn) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
// Check for MFA response passed through the context.
if mfaResp, err := mfa.MFAResponseFromContext(ctx); err == nil {
Expand Down
19 changes: 9 additions & 10 deletions lib/auth/helpers_mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"github.com/gravitational/teleport/api/client/proto"
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
"github.com/gravitational/teleport/api/mfa"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/mocku2f"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
Expand Down Expand Up @@ -101,24 +102,22 @@ type authClientI interface {
AddMFADeviceSync(context.Context, *proto.AddMFADeviceSyncRequest) (*proto.AddMFADeviceSyncResponse, error)
}

func (d *TestDevice) registerDevice(
ctx context.Context, authClient authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice) error {
// Re-authenticate using MFA.
authnChal, err := authClient.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{
Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{
ContextUser: &proto.ContextUser{},
func (d *TestDevice) registerDevice(ctx context.Context, authClient authClientI, devName string, devType proto.DeviceType, authenticator *TestDevice) error {
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: authClient.CreateAuthenticateChallenge,
SolveAuthenticateChallenge: func(_ context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return authenticator.SolveAuthn(chal)
},
}

authnSolved, err := mfaCeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
Joerger marked this conversation as resolved.
Show resolved Hide resolved
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_MANAGE_DEVICES,
},
})
if err != nil {
return trace.Wrap(err)
}
authnSolved, err := authenticator.SolveAuthn(authnChal)
if err != nil {
return trace.Wrap(err)
}

// Acquire and solve registration challenge.
usage := proto.DeviceUsage_DEVICE_USAGE_MFA
Expand Down
Loading
Loading