Skip to content

Commit

Permalink
Add sso channel ID for front end broadcast channel logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Oct 23, 2024
1 parent 316f3d9 commit 027a025
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 8 deletions.
3 changes: 3 additions & 0 deletions lib/client/weblogin.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,9 @@ type SSOChallenge struct {
RequestID string `json:"requestId,omitempty"`
RedirectURL string `json:"redirectUrl,omitempty"`
Device *SSOMFADevice `json:"device"`
// ChannelID is used by the front end to differentiate multiple ongoing SSO
// MFA requests so they don't interfere with each other.
ChannelID string `json:"channelId"`
}

// SSOMFADevice is a json compatible [proto.SSOMFADevice].
Expand Down
11 changes: 7 additions & 4 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2208,15 +2208,18 @@ func ConstructSSHResponse(response AuthParams) (*url.URL, error) {
return nil, trace.Wrap(err)
}

// Extract secret out of the request.
secretKey := u.Query().Get("secret_key")

// We don't use a secret key for WebUI SSO MFA redirects. The request ID itself is
// kept a secret on the front end to minimize the risk of a phishing attack.
if response.ClientRedirectURL == sso.WebMFARedirect && response.MFAToken != "" {
u.RawQuery = url.Values{"response": {string(out)}}.Encode()
if secretKey == "" && u.Path == sso.WebMFARedirect && response.MFAToken != "" {
q := u.Query()
q.Add("response", string(out))
u.RawQuery = q.Encode()
return u, nil
}

// Extract secret out of the request.
secretKey := u.Query().Get("secret_key")
if secretKey == "" {
return nil, trace.BadParameter("missing secret_key")
}
Expand Down
21 changes: 17 additions & 4 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ import (
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/gogo/protobuf/proto"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
Expand Down Expand Up @@ -611,28 +613,39 @@ func (t *sshBaseHandler) issueSessionMFACerts(ctx context.Context, tc *client.Te
}

func newMFACeremony(stream *terminal.WSStream, createAuthenticateChallenge mfa.CreateAuthenticateChallengeFunc) *mfa.Ceremony {
var channelID string
return &mfa.Ceremony{
CreateAuthenticateChallenge: createAuthenticateChallenge,
PromptConstructor: func(...mfa.PromptOpt) mfa.Prompt {
return newMFAPrompt(stream)
return newMFAPrompt(stream, channelID)
},
SSOMFACeremonyConstructor: func(ctx context.Context) (mfa.SSOMFACeremony, error) {
channelID = uuid.NewString()
u, err := url.Parse(sso.WebMFARedirect)
if err != nil {
return nil, trace.Wrap(err)
}
u.RawQuery = url.Values{"channel_id": {channelID}}.Encode()
return &sso.MFACeremony{
ClientCallbackURL: sso.WebMFARedirect,
ClientCallbackURL: u.String(),
}, nil
},
}
}

func newMFAPrompt(stream *terminal.WSStream) mfa.Prompt {
// create a new MFA prompt. When provided, ssoChannelID is used by the front end to differentiate
// between separate ongoing SSO challenges.
func newMFAPrompt(stream *terminal.WSStream, ssoChannelID string) mfa.Prompt {
return mfa.PromptFunc(func(ctx context.Context, chal *authproto.MFAAuthenticateChallenge) (*authproto.MFAAuthenticateResponse, error) {
// Convert from proto to JSON types.
var challenge client.MFAAuthenticateChallenge
if chal.WebauthnChallenge != nil {
challenge.WebauthnChallenge = wantypes.CredentialAssertionFromProto(chal.WebauthnChallenge)
}

if chal.SSOChallenge != nil {
challenge.SSOChallenge = client.SSOChallengeFromProto(chal.SSOChallenge)
challenge.SSOChallenge.ChannelID = ssoChannelID
}

if chal.WebauthnChallenge == nil && chal.SSOChallenge == nil {
Expand Down Expand Up @@ -804,7 +817,7 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor
if t.participantMode == types.SessionModeratorMode {
beforeStart = func(out io.Writer) {
nc.OnMFA = func() {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream)); err != nil {
if err := t.presenceChecker(ctx, out, t.userAuthClient, t.sessionData.ID.String(), newMFAPrompt(t.stream.WSStream, "" /*ssoChannelID*/)); err != nil {
t.log.WithError(err).Warn("Unable to stream terminal - failure performing presence checks")
return
}
Expand Down

0 comments on commit 027a025

Please sign in to comment.