diff --git a/lib/auth/github.go b/lib/auth/github.go index f07f53299f9ed..d49c789f81349 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -979,14 +979,14 @@ func ValidateClientRedirect(clientRedirect string, ssoTestFlow bool, settings *t // they're used a lot in test code return nil } - if clientRedirect == sso.WebMFARedirect { - // If this is a SSO redirect in the WebUI, allow. - return nil - } u, err := url.Parse(clientRedirect) if err != nil { return trace.Wrap(err, "parsing client redirect URL") } + if u.Path == sso.WebMFARedirect { + // If this is a SSO redirect in the WebUI, allow. + return nil + } if u.Opaque != "" { return trace.BadParameter("unexpected opaque client redirect URL") } diff --git a/lib/client/weblogin.go b/lib/client/weblogin.go index 0d8ca5e99bb4b..7edf946c0e39f 100644 --- a/lib/client/weblogin.go +++ b/lib/client/weblogin.go @@ -474,6 +474,9 @@ type SSOChallenge struct { RequestID string `json:"requestId,omitempty"` RedirectURL string `json:"redirectUrl,omitempty"` Device *SSOMFADevice `json:"device"` + // ChannelID is used by the front end to differentiate multiple ongoing SSO + // MFA requests so they don't interfere with each other. + ChannelID string `json:"channelId"` } // SSOMFADevice is a json compatible [proto.SSOMFADevice]. diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 447713bbe1fda..29c35f77fbc16 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -2222,15 +2222,18 @@ func ConstructSSHResponse(response AuthParams) (*url.URL, error) { return nil, trace.Wrap(err) } + // Extract secret out of the request. + secretKey := u.Query().Get("secret_key") + // We don't use a secret key for WebUI SSO MFA redirects. The request ID itself is // kept a secret on the front end to minimize the risk of a phishing attack. - if response.ClientRedirectURL == sso.WebMFARedirect && response.MFAToken != "" { - u.RawQuery = url.Values{"response": {string(out)}}.Encode() + if secretKey == "" && u.Path == sso.WebMFARedirect && response.MFAToken != "" { + q := u.Query() + q.Add("response", string(out)) + u.RawQuery = q.Encode() return u, nil } - // Extract secret out of the request. - secretKey := u.Query().Get("secret_key") if secretKey == "" { return nil, trace.BadParameter("missing secret_key") } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 425b703f70f6d..8d676cf0d43e8 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -26,6 +26,7 @@ import ( "io" "net" "net/http" + "net/url" "strconv" "strings" "sync" @@ -33,6 +34,7 @@ import ( "time" "github.com/gogo/protobuf/proto" + "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -611,28 +613,39 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te } func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc) *mfa.Ceremony { + var channelID string return &mfa.Ceremony{ CreateAuthenticateChallenge: createAuthenticateChallenge, PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt { - return newMFAPrompt(stream) + return newMFAPrompt(stream, channelID) }, SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) { + channelID = uuid.NewString() + u, err := url.Parse(sso.WebMFARedirect) + if err != nil { + return nil, trace.Wrap(err) + } + u.RawQuery = url.Values{"channel_id": {channelID}}.Encode() return &sso.MFACeremony{ - ClientCallbackURL: sso.WebMFARedirect, + ClientCallbackURL: u.String(), }, nil }, } } -func newMFAPrompt(stream *terminal.WSStream) mfa.Prompt { +// create a new MFA prompt. When provided, ssoChannelID is used by the front end to differentiate +// between separate ongoing SSO challenges. +func newMFAPrompt(stream *terminal.WSStream, ssoChannelID string) mfa.Prompt { return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) { // Convert from proto to JSON types. var challenge client.MFAAuthenticateChallenge if chal.WebauthnChallenge != nil { challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge) } + if chal.SSOChallenge != nil { challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge) + challenge.SSOChallenge.ChannelID = ssoChannelID } if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil { @@ -804,7 +817,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor if t.participantMode == types.SessionModeratorMode { beforeStart = func(out io.Writer) { nc.OnMFA = func() { - if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream)); err != nil { + if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream, "" /*ssoChannelID*/)); err != nil { t.log.WithError(err).Warn("Unable to stream terminal - failure performing presence checks") return }