Skip to content

Commit

Permalink
fix: Always prompt if multiple MFA methods are running (#47114) (#47154)
Browse files Browse the repository at this point in the history
* fix: Always prompt if multiple MFA methods are running

* Fix deadlock on OTP+WebAuthn+PIN prompts

* Fix races on wanwin.PromptPlatformMessage
  • Loading branch information
codingllama authored Oct 7, 2024
1 parent 5baa3a6 commit 7ad25b5
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 63 deletions.
41 changes: 31 additions & 10 deletions lib/auth/webauthnwin/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"fmt"
"io"
"os"
"sync"

"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/protocol/webauthncose"
Expand Down Expand Up @@ -166,22 +167,42 @@ func Register(_ context.Context, origin string, cc *wantypes.CredentialCreation)

const defaultPromptMessage = "Using platform authenticator, follow the OS dialogs"

var (
// PromptPlatformMessage is the message shown before system prompts.
PromptPlatformMessage = defaultPromptMessage
// promptPlatformMessage is the message shown before system prompts.
var promptPlatformMessage = struct {
mu sync.Mutex
message string
}{
message: defaultPromptMessage,
}

// PromptWriter is the writer used for prompt messages.
PromptWriter io.Writer = os.Stderr
)
// PromptWriter is the writer used for prompt messages.
var PromptWriter io.Writer = os.Stderr

// ResetPromptPlatformMessage resets [PromptPlatformMessage] to its original state.
// SetPromptPlatformMessage assigns a new prompt platform message. The prompt
// platform message is shown by [Login] or [Register] when prompting for a
// device touch.
//
// See [ResetPromptPlatformMessage].
func SetPromptPlatformMessage(message string) {
promptPlatformMessage.mu.Lock()
promptPlatformMessage.message = message
promptPlatformMessage.mu.Unlock()
}

// ResetPromptPlatformMessage resets the prompt platform message to its original
// state.
//
// See [SetPromptPlatformMessage].
func ResetPromptPlatformMessage() {
PromptPlatformMessage = defaultPromptMessage
SetPromptPlatformMessage(defaultPromptMessage)
}

func promptPlatform() {
if PromptPlatformMessage != "" {
fmt.Fprintln(PromptWriter, PromptPlatformMessage)
promptPlatformMessage.mu.Lock()
defer promptPlatformMessage.mu.Unlock()

if msg := promptPlatformMessage.message; msg != "" {
fmt.Fprintln(PromptWriter, msg)
}
}

Expand Down
7 changes: 3 additions & 4 deletions lib/client/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@ import (

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/mfa"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
libmfa "github.com/gravitational/teleport/lib/client/mfa"
)

// WebauthnLoginFunc matches the signature of [wancli.Login].
type WebauthnLoginFunc func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error)
// WebauthnLoginFunc is a function that performs WebAuthn login.
// Mimics the signature of [webauthncli.Login].
type WebauthnLoginFunc = libmfa.WebauthnLoginFunc

// NewMFAPrompt creates a new MFA prompt from client settings.
func (tc *TeleportClient) NewMFAPrompt(opts ...mfa.PromptOpt) mfa.Prompt {
Expand Down
94 changes: 63 additions & 31 deletions lib/client/mfa/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import (
"context"
"fmt"
"io"
"runtime"
"sync"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils/prompt"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
Expand Down Expand Up @@ -65,26 +67,42 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng

// Depending on the run opts, we may spawn a TOTP goroutine, webauth goroutine, or both.
spawnGoroutines := func(ctx context.Context, wg *sync.WaitGroup, respC chan<- MFAGoroutineResponse) {
// Use variables below to cancel OTP reads and make sure the goroutine exited.
otpCtx, otpCancel := context.WithCancel(ctx)
otpDone := make(chan struct{})
otpCancelAndWait := func() {
otpCancel()
<-otpDone
dualPrompt := runOpts.PromptTOTP && runOpts.PromptWebauthn

// Print the prompt message directly here in case of dualPrompt.
// This avoids problems with a goroutine failing before any message is
// printed.
if dualPrompt {
var message string
if runtime.GOOS == constants.WindowsOS {
message = "Follow the OS dialogs for platform authentication, or enter an OTP code here:"
webauthnwin.SetPromptPlatformMessage("")
} else {
message = fmt.Sprintf("Tap any %ssecurity key or enter a code from a %sOTP device", c.promptDevicePrefix(), c.promptDevicePrefix())
}
fmt.Fprintln(c.writer, message)
}

// Fire TOTP goroutine.
var otpCancelAndWait func()
if runOpts.PromptTOTP {
otpCtx, otpCancel := context.WithCancel(ctx)
otpDone := make(chan struct{})
otpCancelAndWait = func() {
otpCancel()
<-otpDone
}

wg.Add(1)
go func() {
defer wg.Done()
defer otpCancel()
defer close(otpDone)

// Let Webauthn take the prompt below if applicable.
quiet := c.cfg.Quiet || runOpts.PromptWebauthn

resp, err := c.promptTOTP(otpCtx, chal, quiet)
defer func() {
wg.Done()
otpCancel()
close(otpDone)
}()

quiet := c.cfg.Quiet || dualPrompt
resp, err := c.promptTOTP(otpCtx, quiet)
respC <- MFAGoroutineResponse{Resp: resp, Err: trace.Wrap(err, "TOTP authentication failed")}
}()
}
Expand All @@ -93,11 +111,15 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng
if runOpts.PromptWebauthn {
wg.Add(1)
go func() {
defer wg.Done()
defer func() {
wg.Done()
// Important for dual-prompt, harmless otherwise.
webauthnwin.ResetPromptPlatformMessage()
}()

// Get webauthn prompt and wrap with otp context handler.
prompt := &webauthnPromptWithOTP{
LoginPrompt: c.getWebauthnPrompt(ctx, runOpts.PromptTOTP),
LoginPrompt: c.getWebauthnPrompt(ctx, dualPrompt),
otpCancelAndWait: otpCancelAndWait,
}

Expand All @@ -110,7 +132,7 @@ func (c *CLIPrompt) Run(ctx context.Context, chal *proto.MFAAuthenticateChalleng
return HandleMFAPromptGoroutines(ctx, spawnGoroutines)
}

func (c *CLIPrompt) promptTOTP(ctx context.Context, chal *proto.MFAAuthenticateChallenge, quiet bool) (*proto.MFAAuthenticateResponse, error) {
func (c *CLIPrompt) promptTOTP(ctx context.Context, quiet bool) (*proto.MFAAuthenticateResponse, error) {
var msg string
if !quiet {
msg = fmt.Sprintf("Enter an OTP code from a %sdevice", c.promptDevicePrefix())
Expand All @@ -128,7 +150,7 @@ func (c *CLIPrompt) promptTOTP(ctx context.Context, chal *proto.MFAAuthenticateC
}, nil
}

func (c *CLIPrompt) getWebauthnPrompt(ctx context.Context, withTOTP bool) wancli.LoginPrompt {
func (c *CLIPrompt) getWebauthnPrompt(ctx context.Context, dualPrompt bool) wancli.LoginPrompt {
writer := c.writer
if c.cfg.Quiet {
writer = io.Discard
Expand All @@ -138,13 +160,10 @@ func (c *CLIPrompt) getWebauthnPrompt(ctx context.Context, withTOTP bool) wancli
prompt.SecondTouchMessage = fmt.Sprintf("Tap your %ssecurity key to complete login", c.promptDevicePrefix())
prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key", c.promptDevicePrefix())

if withTOTP {
prompt.FirstTouchMessage = fmt.Sprintf("Tap any %ssecurity key or enter a code from a %sOTP device", c.promptDevicePrefix(), c.promptDevicePrefix())

// Customize Windows prompt directly.
// Note that the platform popup is a modal and will only go away if canceled.
webauthnwin.PromptPlatformMessage = "Follow the OS dialogs for platform authentication, or enter an OTP code here:"
defer webauthnwin.ResetPromptPlatformMessage()
// Skip when both OTP and WebAuthn are possible, as the prompt happens
// externally.
if dualPrompt {
prompt.FirstTouchMessage = ""
}

return prompt
Expand Down Expand Up @@ -173,12 +192,25 @@ func (c *CLIPrompt) promptDevicePrefix() string {
// authenticators out there.
type webauthnPromptWithOTP struct {
wancli.LoginPrompt
otpCancelAndWait func()

otpCancelAndWaitOnce sync.Once
otpCancelAndWait func()
}

func (w *webauthnPromptWithOTP) PromptPIN() (string, error) {
// If we get to this stage, Webauthn PIN verification is underway.
// Cancel otp goroutine so that it doesn't capture the PIN from stdin.
w.otpCancelAndWait()
return w.LoginPrompt.PromptPIN()
func (w *webauthnPromptWithOTP) PromptTouch() (wancli.TouchAcknowledger, error) {
ack, err := w.LoginPrompt.PromptTouch()
if err != nil {
return nil, trace.Wrap(err)
}

return func() error {
err := ack()

// Stop the OTP goroutine when the first touch is acknowledged.
if w.otpCancelAndWait != nil {
w.otpCancelAndWaitOnce.Do(w.otpCancelAndWait)
}

return trace.Wrap(err)
}, nil
}
132 changes: 117 additions & 15 deletions lib/client/mfa/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package mfa_test
import (
"bytes"
"context"
"errors"
"testing"
"time"

Expand All @@ -39,12 +40,13 @@ func TestCLIPrompt(t *testing.T) {
ctx := context.Background()

for _, tc := range []struct {
name string
stdin string
challenge *proto.MFAAuthenticateChallenge
expectErr error
expectStdOut string
expectResp *proto.MFAAuthenticateResponse
name string
stdin string
challenge *proto.MFAAuthenticateChallenge
expectErr error
expectStdOut string
expectResp *proto.MFAAuthenticateResponse
makeWebauthnLoginFunc func(stdin *prompt.FakeReader) mfa.WebauthnLoginFunc
}{
{
name: "OK empty challenge",
Expand Down Expand Up @@ -126,6 +128,102 @@ func TestCLIPrompt(t *testing.T) {
},
expectErr: context.DeadlineExceeded,
},
{
name: "OK otp and webauthn with PIN",
challenge: &proto.MFAAuthenticateChallenge{
TOTP: &proto.TOTPChallenge{},
WebauthnChallenge: &webauthnpb.CredentialAssertion{},
},
expectStdOut: `Tap any security key or enter a code from a OTP device
Detected security key tap
Enter your security key PIN:
`,
expectResp: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: &webauthnpb.CredentialAssertionResponse{
RawId: []byte{1, 2, 3, 4, 5},
},
},
},
makeWebauthnLoginFunc: func(stdin *prompt.FakeReader) mfa.WebauthnLoginFunc {
return func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
ack, err := prompt.PromptTouch()
if err != nil {
return nil, "", trace.Wrap(err)
}

// Ack first (so the OTP goroutine stops)...
if err := ack(); err != nil {
return nil, "", trace.Wrap(err)
}

// ...then send the PIN to stdin...
const pin = "1234"
stdin.AddString(pin)

// ...then prompt for the PIN.
switch got, err := prompt.PromptPIN(); {
case err != nil:
return nil, "", trace.Wrap(err)
case got != pin:
return nil, "", errors.New("invalid PIN")
}

return &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: &webauthnpb.CredentialAssertionResponse{
RawId: []byte{1, 2, 3, 4, 5},
},
},
}, "", nil
}
},
},
{
name: "OK webauthn with PIN",
challenge: &proto.MFAAuthenticateChallenge{
TOTP: nil, // no TOTP challenge
WebauthnChallenge: &webauthnpb.CredentialAssertion{},
},
stdin: "1234",
expectStdOut: `Tap any security key
Detected security key tap
Enter your security key PIN:
`,
expectResp: &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: &webauthnpb.CredentialAssertionResponse{
RawId: []byte{1, 2, 3, 4, 5},
},
},
},
makeWebauthnLoginFunc: func(_ *prompt.FakeReader) mfa.WebauthnLoginFunc {
return func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
ack, err := prompt.PromptTouch()
if err != nil {
return nil, "", trace.Wrap(err)
}
if err := ack(); err != nil {
return nil, "", trace.Wrap(err)
}

switch got, err := prompt.PromptPIN(); {
case err != nil:
return nil, "", trace.Wrap(err)
case got != "1234":
return nil, "", errors.New("invalid PIN")
}

return &proto.MFAAuthenticateResponse{
Response: &proto.MFAAuthenticateResponse_Webauthn{
Webauthn: &webauthnpb.CredentialAssertionResponse{
RawId: []byte{1, 2, 3, 4, 5},
},
},
}, "", nil
}
},
},
} {
t.Run(tc.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
Expand All @@ -143,17 +241,21 @@ func TestCLIPrompt(t *testing.T) {
cfg := mfa.NewPromptConfig("proxy.example.com")
cfg.AllowStdinHijack = true
cfg.WebauthnSupported = true
cfg.WebauthnLoginFunc = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
if _, err := prompt.PromptTouch(); err != nil {
return nil, "", trace.Wrap(err)
}
if tc.makeWebauthnLoginFunc != nil {
cfg.WebauthnLoginFunc = tc.makeWebauthnLoginFunc(stdin)
} else {
cfg.WebauthnLoginFunc = func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) {
if _, err := prompt.PromptTouch(); err != nil {
return nil, "", trace.Wrap(err)
}

if tc.expectResp.GetWebauthn() == nil {
<-ctx.Done()
return nil, "", trace.Wrap(ctx.Err())
}
if tc.expectResp.GetWebauthn() == nil {
<-ctx.Done()
return nil, "", trace.Wrap(ctx.Err())
}

return tc.expectResp, "", nil
return tc.expectResp, "", nil
}
}

buffer := make([]byte, 0, 100)
Expand Down
Loading

0 comments on commit 7ad25b5

Please sign in to comment.