Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v17] feat: SSO MFA - Add SSO MFA ceremony #48157

Merged
merged 12 commits into from
Oct 30, 2024
8 changes: 8 additions & 0 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,9 @@ type Config struct {
// MFAPromptConstructor is used to create MFA prompts when needed.
// If nil, the client will not prompt for MFA.
MFAPromptConstructor mfa.PromptConstructor
// SSOMFACeremonyConstructor is used to handle SSO MFA when needed.
// If nil, the client will not prompt for MFA.
SSOMFACeremonyConstructor mfa.SSOMFACeremonyConstructor
}

// CheckAndSetDefaults checks and sets default config values.
Expand Down Expand Up @@ -730,6 +733,11 @@ func (c *Client) SetMFAPromptConstructor(pc mfa.PromptConstructor) {
c.c.MFAPromptConstructor = pc
}

// SetSSOMFACeremonyConstructor sets the SSO MFA ceremony constructor for this client.
func (c *Client) SetSSOMFACeremonyConstructor(scc mfa.SSOMFACeremonyConstructor) {
c.c.SSOMFACeremonyConstructor = scc
}

// Close closes the Client connection to the auth server.
func (c *Client) Close() error {
if c.setClosed() && c.conn != nil {
Expand Down
1 change: 1 addition & 0 deletions api/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func (c *Client) PerformMFACeremony(ctx context.Context, challengeRequest *proto
mfaCeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: c.CreateAuthenticateChallenge,
PromptConstructor: c.c.MFAPromptConstructor,
SSOMFACeremonyConstructor: c.c.SSOMFACeremonyConstructor,
}
return mfaCeremony.Run(ctx, challengeRequest, promptOpts...)
}
26 changes: 26 additions & 0 deletions api/mfa/ceremony.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,21 @@ type Ceremony struct {
CreateAuthenticateChallenge CreateAuthenticateChallengeFunc
// PromptConstructor creates a prompt to prompt the user to solve an authentication challenge.
PromptConstructor PromptConstructor
// SSOMFACeremonyConstructor is an optional SSO MFA ceremony constructor. If provided,
// the MFA ceremony will also attempt to retrieve an SSO MFA challenge.
SSOMFACeremonyConstructor SSOMFACeremonyConstructor
}

// SSOMFACeremony is an SSO MFA ceremony.
type SSOMFACeremony interface {
GetClientCallbackURL() string
Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error)
Close()
}

// SSOMFACeremonyConstructor constructs a new SSO MFA ceremony.
type SSOMFACeremonyConstructor func(ctx context.Context) (SSOMFACeremony, error)

// CreateAuthenticateChallengeFunc is a function that creates an authentication challenge.
type CreateAuthenticateChallengeFunc func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error)

Expand All @@ -54,6 +67,19 @@ func (c *Ceremony) Run(ctx context.Context, req *proto.CreateAuthenticateChallen
return nil, trace.BadParameter("mfa challenge scope must be specified")
}

// If available, prepare an SSO MFA ceremony and set the client redirect URL in the challenge
// request to request an SSO challenge in addition to other challenges.
if c.SSOMFACeremonyConstructor != nil {
ssoMFACeremony, err := c.SSOMFACeremonyConstructor(ctx)
if err != nil {
return nil, trace.Wrap(err, "failed to handle SSO MFA ceremony")
}
defer ssoMFACeremony.Close()

req.SSOClientRedirectURL = ssoMFACeremony.GetClientCallbackURL()
promptOpts = append(promptOpts, withSSOMFACeremony(ssoMFACeremony))
}

chal, err := c.CreateAuthenticateChallenge(ctx, req)
if err != nil {
// CreateAuthenticateChallenge returns a bad parameter error when the client
Expand Down
77 changes: 76 additions & 1 deletion api/mfa/ceremony_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ import (

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/client/proto"
mfav1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/mfa/v1"
"github.com/gravitational/teleport/api/mfa"
)

func TestPerformMFACeremony(t *testing.T) {
func TestMFACeremony(t *testing.T) {
t.Parallel()
ctx := context.Background()

Expand Down Expand Up @@ -128,3 +129,77 @@ func TestPerformMFACeremony(t *testing.T) {
})
}
}

func TestMFACeremony_SSO(t *testing.T) {
t.Parallel()
ctx := context.Background()

testMFAChallenge := &proto.MFAAuthenticateChallenge{
SSOChallenge: &proto.SSOChallenge{
RedirectUrl: "redirect",
RequestId: "request-id",
},
}
testMFAResponse := &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_SSO{
SSO: &proto.SSOResponse{
Token: "token",
RequestId: "request-id",
},
},
}

ssoMFACeremony := &mfa.Ceremony{
CreateAuthenticateChallenge: func(ctx context.Context, req *proto.CreateAuthenticateChallengeRequest) (*proto.MFAAuthenticateChallenge, error) {
return testMFAChallenge, nil
},
PromptConstructor: func(opts ...mfa.PromptOpt) mfa.Prompt {
cfg := new(mfa.PromptConfig)
for _, opt := range opts {
opt(cfg)
}

return mfa.PromptFunc(func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
if cfg.SSOMFACeremony == nil {
return nil, trace.BadParameter("expected sso mfa ceremony")
}

return cfg.SSOMFACeremony.Run(ctx, chal)
})
},
SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) {
return &mockSSOMFACeremony{
clientCallbackURL: "client-redirect",
prompt: func(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return testMFAResponse, nil
},
}, nil
},
}

resp, err := ssoMFACeremony.Run(ctx, &proto.CreateAuthenticateChallengeRequest{
ChallengeExtensions: &mfav1.ChallengeExtensions{
Scope: mfav1.ChallengeScope_CHALLENGE_SCOPE_ADMIN_ACTION,
},
MFARequiredCheck: &proto.IsMFARequiredRequest{},
})
require.NoError(t, err)
require.Equal(t, testMFAResponse, resp)
}

type mockSSOMFACeremony struct {
clientCallbackURL string
prompt mfa.PromptFunc
}

// GetClientCallbackURL returns the client callback URL.
func (m *mockSSOMFACeremony) GetClientCallbackURL() string {
return m.clientCallbackURL
}

// Run the SSO MFA ceremony.
func (m *mockSSOMFACeremony) Run(ctx context.Context, chal *proto.MFAAuthenticateChallenge) (*proto.MFAAuthenticateResponse, error) {
return m.prompt(ctx, chal)
}

func (m *mockSSOMFACeremony) Close() {}
9 changes: 9 additions & 0 deletions api/mfa/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ type PromptConfig struct {
// Extensions are the challenge extensions used to create the prompt's challenge.
// Used to enrich certain prompts.
Extensions *mfav1.ChallengeExtensions
// SSOMFACeremony is an SSO MFA ceremony.
SSOMFACeremony SSOMFACeremony
}

// DeviceDescriptor is a descriptor for a device, such as "registered".
Expand Down Expand Up @@ -117,3 +119,10 @@ func WithPromptChallengeExtensions(exts *mfav1.ChallengeExtensions) PromptOpt {
cfg.Extensions = exts
}
}

// withSSOMFACeremony sets the SSO MFA ceremony for the MFA prompt.
func withSSOMFACeremony(ssoMFACeremony SSOMFACeremony) PromptOpt {
return func(cfg *PromptConfig) {
cfg.SSOMFACeremony = ssoMFACeremony
}
}
12 changes: 9 additions & 3 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,9 @@ type Config struct {
// authenticators, such as remote hosts or virtual machines.
PreferOTP bool

// PreferSSO prefers SSO in favor of other MFA methods.
PreferSSO bool

// CheckVersions will check that client version is compatible
// with auth server version when connecting.
CheckVersions bool
Expand Down Expand Up @@ -3043,6 +3046,8 @@ func (tc *TeleportClient) ConnectToCluster(ctx context.Context) (_ *ClusterClien
return nil, trace.NewAggregate(err, pclt.Close())
}
authClientCfg.MFAPromptConstructor = tc.NewMFAPrompt
authClientCfg.SSOMFACeremonyConstructor = tc.NewSSOMFACeremony

authClient, err := authclient.NewClient(authClientCfg)
if err != nil {
return nil, trace.NewAggregate(err, pclt.Close())
Expand Down Expand Up @@ -5062,9 +5067,10 @@ func (tc *TeleportClient) NewKubernetesServiceClient(ctx context.Context, cluste
Credentials: []client.Credentials{
client.LoadTLS(tlsConfig),
},
ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired,
InsecureAddressDiscovery: tc.InsecureSkipVerify,
MFAPromptConstructor: tc.NewMFAPrompt,
ALPNConnUpgradeRequired: tc.TLSRoutingConnUpgradeRequired,
InsecureAddressDiscovery: tc.InsecureSkipVerify,
MFAPromptConstructor: tc.NewMFAPrompt,
SSOMFACeremonyConstructor: tc.NewSSOMFACeremony,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
5 changes: 3 additions & 2 deletions lib/client/cluster_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ func TestIssueUserCertsWithMFA(t *testing.T) {
tc: &TeleportClient{
localAgent: agent,
Config: Config{
SiteName: "test",
Tracer: tracing.NoopTracer("test"),
WebProxyAddr: "proxy.example.com",
SiteName: "test",
Tracer: tracing.NoopTracer("test"),
MFAPromptConstructor: func(cfg *libmfa.PromptConfig) mfa.Prompt {
return test.prompt
},
Expand Down
19 changes: 19 additions & 0 deletions lib/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/mfa"
libmfa "github.com/gravitational/teleport/lib/client/mfa"
"github.com/gravitational/teleport/lib/client/sso"
)

// NewMFACeremony returns a new MFA ceremony configured for this client.
func (tc *TeleportClient) NewMFACeremony() *mfa.Ceremony {
return &mfa.Ceremony{
CreateAuthenticateChallenge: tc.createAuthenticateChallenge,
PromptConstructor: tc.NewMFAPrompt,
SSOMFACeremonyConstructor: tc.NewSSOMFACeremony,
}
}

Expand Down Expand Up @@ -61,6 +63,7 @@ func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) mfa.Prompt {
PromptConfig: *cfg,
Writer: tc.Stderr,
PreferOTP: tc.PreferOTP,
PreferSSO: tc.PreferSSO,
AllowStdinHijack: tc.AllowStdinHijack,
StdinFunc: tc.StdinFunc,
})
Expand All @@ -79,5 +82,21 @@ func (tc *TeleportClient) newPromptConfig(opts ...mfa.PromptOpt) *libmfa.PromptC
cfg.WebauthnLoginFunc = tc.WebauthnLogin
cfg.WebauthnSupported = true
}

return cfg
}

// NewSSOMFACeremony creates a new SSO MFA ceremony.
func (tc *TeleportClient) NewSSOMFACeremony(ctx context.Context) (mfa.SSOMFACeremony, error) {
rdConfig, err := tc.ssoRedirectorConfig(ctx, "" /*connectorDisplayName*/)
if err != nil {
return nil, trace.Wrap(err)
}

rd, err := sso.NewRedirector(rdConfig)
if err != nil {
return nil, trace.Wrap(err)
}

return sso.NewCLIMFACeremony(rd), nil
}
Loading
Loading