From 8829746a51dbb30861f5c0107cb955a39e9ea88b Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 23 Oct 2024 08:23:56 -0500 Subject: [PATCH 1/2] Rename useWebAuthn to useMfa and handle SSO challenges (#47819) * Rename useWebauthn to useMfa * add ssoChallenge to mfa requests/responses * Update callsites and tests --- .../teleport/src/Account/Account.test.tsx | 16 +- .../DocumentKubeExec/DocumentKubeExec.tsx | 14 +- .../src/Console/DocumentSsh/DocumentSsh.tsx | 16 +- .../DesktopSession/DesktopSession.story.tsx | 18 +-- .../src/DesktopSession/DesktopSession.tsx | 27 ++-- .../src/DesktopSession/useDesktopSession.tsx | 6 +- .../AuthnDialog/AuthnDialog.story.tsx | 9 +- .../components/AuthnDialog/AuthnDialog.tsx | 35 +++-- ...uthnSender.ts => EventEmitterMfaSender.ts} | 4 +- web/packages/teleport/src/lib/tdp/client.ts | 4 +- web/packages/teleport/src/lib/term/tty.ts | 4 +- web/packages/teleport/src/lib/useMfa.ts | 148 ++++++++++++++++++ web/packages/teleport/src/lib/useWebAuthn.ts | 115 -------------- .../teleport/src/services/auth/makeMfa.ts | 13 +- .../teleport/src/services/auth/types.ts | 6 + 15 files changed, 235 insertions(+), 200 deletions(-) rename web/packages/teleport/src/lib/{EventEmitterWebAuthnSender.ts => EventEmitterMfaSender.ts} (91%) create mode 100644 web/packages/teleport/src/lib/useMfa.ts delete mode 100644 web/packages/teleport/src/lib/useWebAuthn.ts diff --git a/web/packages/teleport/src/Account/Account.test.tsx b/web/packages/teleport/src/Account/Account.test.tsx index 6fb23549a0e36..7dcf86f471adb 100644 --- a/web/packages/teleport/src/Account/Account.test.tsx +++ b/web/packages/teleport/src/Account/Account.test.tsx @@ -243,9 +243,11 @@ test('adding an MFA device', async () => { const user = userEvent.setup(); const ctx = createTeleportContext(); jest.spyOn(ctx.mfaService, 'fetchDevices').mockResolvedValue([testPasskey]); - jest - .spyOn(auth, 'getChallenge') - .mockResolvedValue({ webauthnPublicKey: null, totpChallenge: true }); + jest.spyOn(auth, 'getChallenge').mockResolvedValue({ + webauthnPublicKey: null, + totpChallenge: true, + ssoChallenge: null, + }); jest .spyOn(auth, 'createNewWebAuthnDevice') .mockResolvedValueOnce(dummyCredential); @@ -325,9 +327,11 @@ test('removing an MFA method', async () => { const user = userEvent.setup(); const ctx = createTeleportContext(); jest.spyOn(ctx.mfaService, 'fetchDevices').mockResolvedValue([testMfaMethod]); - jest - .spyOn(auth, 'getChallenge') - .mockResolvedValue({ webauthnPublicKey: null, totpChallenge: false }); + jest.spyOn(auth, 'getChallenge').mockResolvedValue({ + webauthnPublicKey: null, + totpChallenge: false, + ssoChallenge: null, + }); jest .spyOn(auth, 'createPrivilegeTokenWithWebauthn') .mockResolvedValueOnce('webauthn-privilege-token'); diff --git a/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx b/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx index 1589c6ef7d347..3b405e034f04c 100644 --- a/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx +++ b/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx @@ -22,7 +22,7 @@ import { Box, Indicator } from 'design'; import * as stores from 'teleport/Console/stores/types'; import { Terminal, TerminalRef } from 'teleport/Console/DocumentSsh/Terminal'; -import useWebAuthn from 'teleport/lib/useWebAuthn'; +import { useMfa } from 'teleport/lib/useMfa'; import useKubeExecSession from 'teleport/Console/DocumentKubeExec/useKubeExecSession'; import Document from 'teleport/Console/Document'; @@ -39,11 +39,11 @@ export default function DocumentKubeExec({ doc, visible }: Props) { const terminalRef = useRef(); const { tty, status, closeDocument, sendKubeExecData } = useKubeExecSession(doc); - const webauthn = useWebAuthn(tty); + const mfa = useMfa(tty); useEffect(() => { // when switching tabs or closing tabs, focus on visible terminal terminalRef.current?.focus(); - }, [visible, webauthn.requested]); + }, [visible, mfa.requested]); const theme = useTheme(); const terminal = ( @@ -63,13 +63,7 @@ export default function DocumentKubeExec({ doc, visible }: Props) { )} - {webauthn.requested && ( - - )} + {mfa.requested && } {status === 'waiting-for-exec-data' && ( diff --git a/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx b/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx index eb2720d7f012e..aacafdc35808a 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx @@ -31,7 +31,7 @@ import { import * as stores from 'teleport/Console/stores'; import AuthnDialog from 'teleport/components/AuthnDialog'; -import useWebAuthn from 'teleport/lib/useWebAuthn'; +import { useMfa } from 'teleport/lib/useMfa'; import Document from '../Document'; @@ -50,13 +50,13 @@ export default function DocumentSshWrapper(props: PropTypes) { function DocumentSsh({ doc, visible }: PropTypes) { const terminalRef = useRef(); const { tty, status, closeDocument, session } = useSshSession(doc); - const webauthn = useWebAuthn(tty); + const mfa = useMfa(tty); const { getMfaResponseAttempt, getDownloader, getUploader, fileTransferRequests, - } = useFileTransfer(tty, session, doc, webauthn.addMfaToScpUrls); + } = useFileTransfer(tty, session, doc, mfa.addMfaToScpUrls); const theme = useTheme(); function handleCloseFileTransfer() { @@ -70,7 +70,7 @@ function DocumentSsh({ doc, visible }: PropTypes) { useEffect(() => { // when switching tabs or closing tabs, focus on visible terminal terminalRef.current?.focus(); - }, [visible, webauthn.requested]); + }, [visible, mfa.requested]); const terminal = ( )} - {webauthn.requested && ( - - )} + {mfa.requested && } {status === 'initialized' && terminal} {}, clientOnClipboardData: async () => {}, setTdpConnection: () => {}, - webauthn: { - errorText: '', - requested: false, - authenticate: () => {}, - setState: () => {}, - addMfaToScpUrls: false, - }, + mfa: makeDefaultMfaState(), showAnotherSessionActiveDialog: false, setShowAnotherSessionActiveDialog: () => {}, alerts: [], @@ -265,12 +260,15 @@ export const WebAuthnPrompt = () => ( writeState: 'granted', }} wsConnection={{ status: 'open' }} - webauthn={{ + mfa={{ errorText: '', requested: true, - authenticate: () => {}, - setState: () => {}, + setErrorText: () => null, addMfaToScpUrls: false, + onWebauthnAuthenticate: () => null, + onSsoAuthenticate: () => null, + webauthnPublicKey: null, + ssoChallenge: null, }} /> ); diff --git a/web/packages/teleport/src/DesktopSession/DesktopSession.tsx b/web/packages/teleport/src/DesktopSession/DesktopSession.tsx index 66a66825a209e..b1f188f2997c9 100644 --- a/web/packages/teleport/src/DesktopSession/DesktopSession.tsx +++ b/web/packages/teleport/src/DesktopSession/DesktopSession.tsx @@ -39,7 +39,7 @@ import useDesktopSession, { import TopBar from './TopBar'; import type { State, WebsocketAttempt } from './useDesktopSession'; -import type { WebAuthnState } from 'teleport/lib/useWebAuthn'; +import type { MfaState } from 'teleport/lib/useMfa'; export function DesktopSessionContainer() { const state = useDesktopSession(); @@ -54,7 +54,7 @@ declare global { export function DesktopSession(props: State) { const { - webauthn, + mfa, tdpClient, username, hostname, @@ -105,7 +105,7 @@ export function DesktopSession(props: State) { tdpConnection, wsConnection, showAnotherSessionActiveDialog, - webauthn + mfa ) ); }, [ @@ -113,7 +113,7 @@ export function DesktopSession(props: State) { tdpConnection, wsConnection, showAnotherSessionActiveDialog, - webauthn, + mfa, ]); return ( @@ -144,7 +144,7 @@ export function DesktopSession(props: State) { {screenState.screen === 'anotherSessionActive' && ( )} - {screenState.screen === 'mfa' && } + {screenState.screen === 'mfa' && } {screenState.screen === 'alert dialog' && ( )} @@ -181,20 +181,15 @@ export function DesktopSession(props: State) { ); } -const MfaDialog = ({ webauthn }: { webauthn: WebAuthnState }) => { +const MfaDialog = ({ mfa }: { mfa: MfaState }) => { return ( { - webauthn.setState(prevState => { - return { - ...prevState, - errorText: - 'This session requires multi factor authentication to continue. Please hit "Retry" and follow the prompts given by your browser to complete authentication.', - }; - }); + mfa.setErrorText( + 'This session requires multi factor authentication to continue. Please hit "Retry" and follow the prompts given by your browser to complete authentication.' + ); }} - errorText={webauthn.errorText} /> ); }; @@ -282,7 +277,7 @@ const nextScreenState = ( tdpConnection: Attempt, wsConnection: WebsocketAttempt, showAnotherSessionActiveDialog: boolean, - webauthn: WebAuthnState + webauthn: MfaState ): ScreenState => { // We always want to show the user the first alert that caused the session to fail/end, // so if we're already showing an alert, don't change the screen. diff --git a/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx b/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx index a49e6d4a268fc..1f642d38d8d96 100644 --- a/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx +++ b/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx @@ -22,7 +22,7 @@ import { useParams } from 'react-router'; import useAttempt from 'shared/hooks/useAttemptNext'; import { ButtonState } from 'teleport/lib/tdp'; -import useWebAuthn from 'teleport/lib/useWebAuthn'; +import { useMfa } from 'teleport/lib/useMfa'; import desktopService from 'teleport/services/desktops'; import userService from 'teleport/services/user'; @@ -130,7 +130,7 @@ export default function useDesktopSession() { }); const tdpClient = clientCanvasProps.tdpClient; - const webauthn = useWebAuthn(tdpClient); + const mfa = useMfa(tdpClient); const onShareDirectory = () => { try { @@ -205,7 +205,7 @@ export default function useDesktopSession() { fetchAttempt, tdpConnection, wsConnection, - webauthn, + mfa, setTdpConnection, showAnotherSessionActiveDialog, setShowAnotherSessionActiveDialog, diff --git a/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.story.tsx b/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.story.tsx index 73600b5a7fb1c..8ec5592c47a0c 100644 --- a/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.story.tsx +++ b/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.story.tsx @@ -18,6 +18,8 @@ import React from 'react'; +import { makeDefaultMfaState } from 'teleport/lib/useMfa'; + import AuthnDialog, { Props } from './AuthnDialog'; export default { @@ -26,12 +28,9 @@ export default { export const Loaded = () => ; -export const Error = () => ( - -); +export const Error = () => ; const props: Props = { - onContinue: () => null, + mfa: makeDefaultMfaState(), onCancel: () => null, - errorText: '', }; diff --git a/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.tsx b/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.tsx index a8b5cd532a1bf..05685c0d6a3eb 100644 --- a/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.tsx +++ b/web/packages/teleport/src/components/AuthnDialog/AuthnDialog.tsx @@ -18,48 +18,51 @@ import React from 'react'; import Dialog, { - DialogFooter, DialogHeader, DialogTitle, DialogContent, } from 'design/Dialog'; import { Danger } from 'design/Alert'; -import { Text, ButtonPrimary, ButtonSecondary } from 'design'; +import { Text, ButtonPrimary, ButtonSecondary, Flex } from 'design'; -export default function AuthnDialog({ - onContinue, - onCancel, - errorText, -}: Props) { +import { MfaState } from 'teleport/lib/useMfa'; + +export default function AuthnDialog({ mfa, onCancel }: Props) { return ( - ({ width: '400px' })} open={true}> + ({ width: '500px' })} open={true}> Multi-factor authentication - {errorText && ( + {mfa.errorText && ( - {errorText} + {mfa.errorText} )} Re-enter your multi-factor authentication in the browser to continue. - - - {errorText ? 'Retry' : 'OK'} + + {/* TODO (avatus) this will eventually be conditionally rendered based on what + type of challenges exist. For now, its only webauthn. */} + + {mfa.errorText ? 'Retry' : 'OK'} Cancel - + ); } export type Props = { - onContinue: () => void; + mfa: MfaState; onCancel: () => void; - errorText: string; }; diff --git a/web/packages/teleport/src/lib/EventEmitterWebAuthnSender.ts b/web/packages/teleport/src/lib/EventEmitterMfaSender.ts similarity index 91% rename from web/packages/teleport/src/lib/EventEmitterWebAuthnSender.ts rename to web/packages/teleport/src/lib/EventEmitterMfaSender.ts index 834746c866bf6..68eae3367f6ea 100644 --- a/web/packages/teleport/src/lib/EventEmitterWebAuthnSender.ts +++ b/web/packages/teleport/src/lib/EventEmitterMfaSender.ts @@ -20,7 +20,7 @@ import { EventEmitter } from 'events'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; -class EventEmitterWebAuthnSender extends EventEmitter { +class EventEmitterMfaSender extends EventEmitter { constructor() { super(); } @@ -31,4 +31,4 @@ class EventEmitterWebAuthnSender extends EventEmitter { } } -export { EventEmitterWebAuthnSender }; +export { EventEmitterMfaSender }; diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index 098cb9d824fa6..6f000b083d820 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -24,7 +24,7 @@ import init, { } from 'teleport/ironrdp/pkg/ironrdp'; import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; -import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { EventEmitterMfaSender } from 'teleport/lib/EventEmitterMfaSender'; import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import Codec, { @@ -93,7 +93,7 @@ export enum LogType { // sending client commands, and receiving and processing server messages. Its creator is responsible for // ensuring the websocket gets closed and all of its event listeners cleaned up when it is no longer in use. // For convenience, this can be done in one fell swoop by calling Client.shutdown(). -export default class Client extends EventEmitterWebAuthnSender { +export default class Client extends EventEmitterMfaSender { protected codec: Codec; protected socket: AuthenticatedWebSocket | undefined; private socketAddr: string; diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 6bd9014234323..2eb11957b8fbd 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -18,7 +18,7 @@ import Logger from 'shared/libs/logger'; -import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { EventEmitterMfaSender } from 'teleport/lib/EventEmitterMfaSender'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; @@ -31,7 +31,7 @@ const defaultOptions = { buffered: true, }; -class Tty extends EventEmitterWebAuthnSender { +class Tty extends EventEmitterMfaSender { socket = null; _buffered = true; diff --git a/web/packages/teleport/src/lib/useMfa.ts b/web/packages/teleport/src/lib/useMfa.ts new file mode 100644 index 0000000000000..8d55cf4c73f75 --- /dev/null +++ b/web/packages/teleport/src/lib/useMfa.ts @@ -0,0 +1,148 @@ +/** + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import { useState, useEffect, useCallback } from 'react'; + +import { EventEmitterMfaSender } from 'teleport/lib/EventEmitterMfaSender'; +import { TermEvent } from 'teleport/lib/term/enums'; +import { + makeMfaAuthenticateChallenge, + makeWebauthnAssertionResponse, + SSOChallenge, +} from 'teleport/services/auth'; + +export function useMfa(emitterSender: EventEmitterMfaSender): MfaState { + const [state, setState] = useState<{ + errorText: string; + addMfaToScpUrls: boolean; + webauthnPublicKey: PublicKeyCredentialRequestOptions; + ssoChallenge: SSOChallenge; + totpChallenge: boolean; + }>({ + addMfaToScpUrls: false, + errorText: '', + webauthnPublicKey: null, + ssoChallenge: null, + totpChallenge: false, + }); + + // TODO (avatus), this is stubbed for types but will not be called + // until SSO as MFA backend is in. + function onSsoAuthenticate() { + // eslint-disable-next-line no-console + console.error('not yet implemented'); + } + + function onWebauthnAuthenticate() { + if (!window.PublicKeyCredential) { + const errorText = + 'This browser does not support WebAuthn required for hardware tokens, \ + please try the latest version of Chrome, Firefox or Safari.'; + + setState({ + ...state, + errorText, + }); + return; + } + + navigator.credentials + .get({ publicKey: state.webauthnPublicKey }) + .then(res => { + setState(prevState => ({ + ...prevState, + errorText: '', + webauthnPublicKey: null, + })); + const credential = makeWebauthnAssertionResponse(res); + emitterSender.sendWebAuthn(credential); + }) + .catch((err: Error) => { + setErrorText(err.message); + }); + } + + const onChallenge = useCallback(challengeJson => { + const { webauthnPublicKey, ssoChallenge, totpChallenge } = + makeMfaAuthenticateChallenge(challengeJson); + + setState(prevState => ({ + ...prevState, + ssoChallenge, + webauthnPublicKey, + totpChallenge, + })); + }, []); + + useEffect(() => { + if (emitterSender) { + emitterSender.on(TermEvent.WEBAUTHN_CHALLENGE, onChallenge); + + return () => { + emitterSender.removeListener(TermEvent.WEBAUTHN_CHALLENGE, onChallenge); + }; + } + }, [emitterSender, onChallenge]); + + function setErrorText(newErrorText: string) { + setState(prevState => ({ ...prevState, errorText: newErrorText })); + } + + // if any challenge exists, requested is true + const requested = !!( + state.webauthnPublicKey || + state.totpChallenge || + state.ssoChallenge + ); + + return { + requested, + onWebauthnAuthenticate, + onSsoAuthenticate, + addMfaToScpUrls: state.addMfaToScpUrls, + setErrorText, + errorText: state.errorText, + webauthnPublicKey: state.webauthnPublicKey, + ssoChallenge: state.ssoChallenge, + }; +} + +export type MfaState = { + onWebauthnAuthenticate: () => void; + onSsoAuthenticate: () => void; + setErrorText: (errorText: string) => void; + errorText: string; + requested: boolean; + addMfaToScpUrls: boolean; + webauthnPublicKey: PublicKeyCredentialRequestOptions; + ssoChallenge: SSOChallenge; +}; + +// used for testing +export function makeDefaultMfaState(): MfaState { + return { + onWebauthnAuthenticate: () => null, + onSsoAuthenticate: () => null, + setErrorText: () => null, + errorText: '', + requested: false, + addMfaToScpUrls: false, + webauthnPublicKey: null, + ssoChallenge: null, + }; +} diff --git a/web/packages/teleport/src/lib/useWebAuthn.ts b/web/packages/teleport/src/lib/useWebAuthn.ts deleted file mode 100644 index 730065299ceed..0000000000000 --- a/web/packages/teleport/src/lib/useWebAuthn.ts +++ /dev/null @@ -1,115 +0,0 @@ -/** - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -import { useState, useEffect, Dispatch, SetStateAction } from 'react'; - -import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; -import { TermEvent } from 'teleport/lib/term/enums'; -import { - makeMfaAuthenticateChallenge, - makeWebauthnAssertionResponse, -} from 'teleport/services/auth'; - -export default function useWebAuthn( - emitterSender: EventEmitterWebAuthnSender -): WebAuthnState { - const [state, setState] = useState({ - addMfaToScpUrls: false, - requested: false, - errorText: '', - publicKey: null as PublicKeyCredentialRequestOptions, - }); - - function authenticate() { - if (!window.PublicKeyCredential) { - const errorText = - 'This browser does not support WebAuthn required for hardware tokens, \ - please try the latest version of Chrome, Firefox or Safari.'; - - setState({ - ...state, - errorText, - }); - return; - } - - navigator.credentials - .get({ publicKey: state.publicKey }) - .then(res => { - const credential = makeWebauthnAssertionResponse(res); - emitterSender.sendWebAuthn(credential); - - setState({ - ...state, - requested: false, - errorText: '', - }); - }) - .catch((err: Error) => { - setState({ - ...state, - errorText: err.message, - }); - }); - } - - const onChallenge = challengeJson => { - const challenge = JSON.parse(challengeJson); - const publicKey = makeMfaAuthenticateChallenge(challenge).webauthnPublicKey; - - setState({ - ...state, - requested: true, - addMfaToScpUrls: true, - publicKey, - }); - }; - - useEffect(() => { - if (emitterSender) { - emitterSender.on(TermEvent.WEBAUTHN_CHALLENGE, onChallenge); - - return () => { - emitterSender.removeListener(TermEvent.WEBAUTHN_CHALLENGE, onChallenge); - }; - } - }, [emitterSender]); - - return { - errorText: state.errorText, - requested: state.requested, - authenticate, - setState, - addMfaToScpUrls: state.addMfaToScpUrls, - }; -} - -export type WebAuthnState = { - errorText: string; - requested: boolean; - authenticate: () => void; - setState: Dispatch< - SetStateAction<{ - addMfaToScpUrls: boolean; - requested: boolean; - errorText: string; - publicKey: PublicKeyCredentialRequestOptions; - }> - >; - addMfaToScpUrls: boolean; -}; diff --git a/web/packages/teleport/src/services/auth/makeMfa.ts b/web/packages/teleport/src/services/auth/makeMfa.ts index 0637967483911..506cca4a874c7 100644 --- a/web/packages/teleport/src/services/auth/makeMfa.ts +++ b/web/packages/teleport/src/services/auth/makeMfa.ts @@ -50,12 +50,15 @@ export function makeMfaRegistrationChallenge(json): MfaRegistrationChallenge { } // makeMfaAuthenticateChallenge formats fetched authenticate challenge JSON. -// Webauthn challange contains Base64URL(byte) fields that needs to +// Webauthn challenge contains Base64URL(byte) fields that needs to // be converted to ArrayBuffer expected by navigator.credentials.get: // - challenge // - allowCredentials[i].id export function makeMfaAuthenticateChallenge(json): MfaAuthenticateChallenge { - const webauthnPublicKey = json.webauthn_challenge?.publicKey; + const challenge = typeof json === 'string' ? JSON.parse(json) : json; + const { sso_challenge, webauthn_challenge } = challenge; + + const webauthnPublicKey = webauthn_challenge?.publicKey; if (webauthnPublicKey) { const challenge = webauthnPublicKey.challenge || ''; const allowCredentials = webauthnPublicKey.allowCredentials || []; @@ -70,6 +73,12 @@ export function makeMfaAuthenticateChallenge(json): MfaAuthenticateChallenge { } return { + ssoChallenge: sso_challenge + ? { + redirectUrl: sso_challenge.redirect_url, + requestId: sso_challenge.request_id, + } + : null, totpChallenge: json.totp_challenge, webauthnPublicKey: webauthnPublicKey, }; diff --git a/web/packages/teleport/src/services/auth/types.ts b/web/packages/teleport/src/services/auth/types.ts index 11057cd185645..170d4eedee272 100644 --- a/web/packages/teleport/src/services/auth/types.ts +++ b/web/packages/teleport/src/services/auth/types.ts @@ -32,7 +32,13 @@ export type AuthnChallengeRequest = { userCred: UserCredentials; }; +export type SSOChallenge = { + redirectUrl: string; + requestId: string; +}; + export type MfaAuthenticateChallenge = { + ssoChallenge: SSOChallenge; totpChallenge: boolean; webauthnPublicKey: PublicKeyCredentialRequestOptions; }; From eda3671595b6507c695dd460ae9d659f0505654a Mon Sep 17 00:00:00 2001 From: Trent Clarke Date: Thu, 24 Oct 2024 00:27:56 +1100 Subject: [PATCH 2/2] Moves and exposes the AWS OIDC credentials cache (#47840) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Moves and exposes the AWS OIDC credentials cache The IdentityCenter integration users OIDC to authenticate with AWS, and so will re-use the existing OIDC credential caching code used by the external audit storage package. This change - extracts the credential cache from the `externalauditstorage` package, - moves it to the `awsoidc` package to indicate that is generally useful, not just for storage access), and - makes it public. This patch also copies the applicable cache tests from `externalauditstorage`. The credential cache tests in `externalauditstorage` have been preserved because they also test backwards compatibility with AWS SDK v1 credential provider, which the new tests do not. * Test fixup * Linter fixups * Remove ttlValuer * Apply suggestions from code review Co-authored-by: Marek Smoliński --------- Co-authored-by: Marek Smoliński --- lib/integrations/awsoidc/credentialscache.go | 284 ++++++++++++++++++ .../awsoidc/credentialscache_test.go | 226 ++++++++++++++ .../externalauditstorage/configurator.go | 206 ++----------- .../externalauditstorage/configurator_test.go | 2 +- lib/service/service_test.go | 2 +- 5 files changed, 530 insertions(+), 190 deletions(-) create mode 100644 lib/integrations/awsoidc/credentialscache.go create mode 100644 lib/integrations/awsoidc/credentialscache_test.go diff --git a/lib/integrations/awsoidc/credentialscache.go b/lib/integrations/awsoidc/credentialscache.go new file mode 100644 index 0000000000000..1d1ddffe3bf1c --- /dev/null +++ b/lib/integrations/awsoidc/credentialscache.go @@ -0,0 +1,284 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package awsoidc + +import ( + "context" + "errors" + "log/slog" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport" +) + +const ( + // TokenLifetime is the lifetime of OIDC tokens used by the + // ExternalAuditStorage service with the AWS OIDC integration. + TokenLifetime = time.Hour + + refreshBeforeExpirationPeriod = 15 * time.Minute + refreshCheckInterval = 30 * time.Second + retrieveTimeout = 30 * time.Second +) + +// GenerateOIDCTokenFn is a function that should return a valid, signed JWT for +// authenticating to AWS via OIDC. +type GenerateOIDCTokenFn func(ctx context.Context, integration string) (string, error) + +type credsOrErr struct { + creds aws.Credentials + err error +} + +// CredentialsCache is used to store and refresh AWS credentials used with +// AWS OIDC integration. +// +// Credentials are valid for 1h, but they cannot be refreshed if Proxy is down, +// so we attempt to refresh the credentials early and retry on failure. +type CredentialsCache struct { + log *slog.Logger + + roleARN arn.ARN + integration string + + // generateOIDCTokenFn is dynamically set after auth is initialized. + generateOIDCTokenFn GenerateOIDCTokenFn + + // initialized communicates (via closing channel) that generateOIDCTokenFn is set. + initialized chan struct{} + closeInitialized func() + + // gotFirstCredsOrErr communicates (via closing channel) that the first + // credsOrErr has been set. + gotFirstCredsOrErr chan struct{} + closeGotFirstCredsOrErr func() + + credsOrErr credsOrErr + credsOrErrMu sync.RWMutex + + stsClient stscreds.AssumeRoleWithWebIdentityAPIClient + clock clockwork.Clock +} + +type CredentialsCacheOptions struct { + // Integration is the name of the Teleport OIDC integration to use + Integration string + + // RoleARN is the ARN of the role to assume once authenticated + RoleARN arn.ARN + + // STSClient is the AWS sts client implementation to use when communicating + // with AWS + STSClient stscreds.AssumeRoleWithWebIdentityAPIClient + + // Log is the logger to use. A default will be supplied if no logger is + // explicitly set + Log *slog.Logger + + // Clock is the clock to use. A default system clock will be provided if + // none is supplied. + Clock clockwork.Clock +} + +func (opts *CredentialsCacheOptions) CheckAndSetDefaults() error { + if opts.STSClient == nil { + return trace.BadParameter("stsClient must be provided") + } + + if opts.Log == nil { + opts.Log = slog.Default().With(teleport.ComponentKey, "AWS-OIDC-CredentialCache") + } + + if opts.Clock == nil { + opts.Clock = clockwork.NewRealClock() + } + + return nil +} + +var errNotReady = errors.New("ExternalAuditStorage: credential cache not yet initialized") + +func NewCredentialsCache(options CredentialsCacheOptions) (*CredentialsCache, error) { + if err := options.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err, "creating credentials cache") + } + + initialized := make(chan struct{}) + gotFirstCredsOrErr := make(chan struct{}) + + return &CredentialsCache{ + roleARN: options.RoleARN, + integration: options.Integration, + log: options.Log.With("integration", options.Integration), + initialized: initialized, + closeInitialized: sync.OnceFunc(func() { close(initialized) }), + gotFirstCredsOrErr: gotFirstCredsOrErr, + closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }), + credsOrErr: credsOrErr{err: errNotReady}, + clock: options.Clock, + stsClient: options.STSClient, + }, nil +} + +func (cc *CredentialsCache) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { + cc.generateOIDCTokenFn = fn + cc.closeInitialized() +} + +// Retrieve implements [aws.CredentialsProvider] and returns the latest cached +// credentials, or an error if no credentials have been generated yet or the +// last generated credentials have expired. +func (cc *CredentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) { + cc.credsOrErrMu.RLock() + defer cc.credsOrErrMu.RUnlock() + + if cc.credsOrErr.err != nil { + cc.log.WarnContext(ctx, "Returning error to AWS client", errorValue(cc.credsOrErr.err)) + } + + return cc.credsOrErr.creds, cc.credsOrErr.err +} + +func (cc *CredentialsCache) Run(ctx context.Context) { + // Wait for initialized signal before running loop. + select { + case <-cc.initialized: + case <-ctx.Done(): + cc.log.DebugContext(ctx, "Context canceled before initialized.") + return + } + + cc.refreshIfNeeded(ctx) + + ticker := cc.clock.NewTicker(refreshCheckInterval) + defer ticker.Stop() + for { + select { + case <-ticker.Chan(): + cc.refreshIfNeeded(ctx) + case <-ctx.Done(): + cc.log.DebugContext(ctx, "Context canceled, stopping refresh loop.") + return + } + } +} + +func (cc *CredentialsCache) refreshIfNeeded(ctx context.Context) { + credsFromCache, err := cc.Retrieve(ctx) + if err == nil && + credsFromCache.HasKeys() && + cc.clock.Now().Add(refreshBeforeExpirationPeriod).Before(credsFromCache.Expires) { + // No need to refresh, credentials in cache are still valid for longer + // than refreshBeforeExpirationPeriod + return + } + cc.log.DebugContext(ctx, "Refreshing credentials.") + + creds, err := cc.refresh(ctx) + if err != nil { + cc.log.WarnContext(ctx, "Failed to retrieve new credentials", errorValue(err)) + now := cc.clock.Now() + // If we were not able to refresh, check if existing credentials in + // cache are still valid. If yes, just log debug, it will be retried on + // next interval check. + if credsFromCache.HasKeys() && now.Before(credsFromCache.Expires) { + cc.log.DebugContext(ctx, "Continuing to use existing credentials", + slog.Duration( + "ttl", + credsFromCache.Expires.Sub(now).Round(time.Second))) + return + } + // If existing creds are expired, update cached error. + cc.log.ErrorContext(ctx, "Setting cached error", "error", err) + cc.setCredsOrErr(credsOrErr{err: trace.Wrap(err)}) + return + } + + // Refresh went well, update cached creds. + cc.setCredsOrErr(credsOrErr{creds: creds}) + cc.log.DebugContext(ctx, "Successfully refreshed credentials", + slog.Time("expires", creds.Expires)) +} + +func (cc *CredentialsCache) setCredsOrErr(coe credsOrErr) { + cc.credsOrErrMu.Lock() + defer cc.credsOrErrMu.Unlock() + cc.credsOrErr = coe + cc.closeGotFirstCredsOrErr() +} + +func (cc *CredentialsCache) refresh(ctx context.Context) (aws.Credentials, error) { + cc.log.InfoContext(ctx, "Refreshing AWS credentials") + defer cc.log.InfoContext(ctx, "Exiting AWS credentials refresh") + + cc.log.InfoContext(ctx, "Generating Token") + oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration) + if err != nil { + cc.log.ErrorContext(ctx, "Token generation failed", errorValue(err)) + return aws.Credentials{}, trace.Wrap(err) + } + + roleProvider := stscreds.NewWebIdentityRoleProvider( + cc.stsClient, + cc.roleARN.String(), + identityToken(oidcToken), + func(wiro *stscreds.WebIdentityRoleOptions) { + wiro.Duration = TokenLifetime + }, + ) + + ctx, cancel := context.WithTimeout(ctx, retrieveTimeout) + defer cancel() + + cc.log.InfoContext(ctx, "Retrieving AWS role credentials") + + creds, err := roleProvider.Retrieve(ctx) + if err != nil { + cc.log.ErrorContext(ctx, "Role retrieval failed", errorValue(err)) + } + + return creds, trace.Wrap(err) +} + +func (cc *CredentialsCache) WaitForFirstCredsOrErr(ctx context.Context) { + cc.log.InfoContext(ctx, "Entering wait on first credential refresh") + defer cc.log.InfoContext(ctx, "Exiting wait on first credential refresh") + + select { + case <-ctx.Done(): + case <-cc.gotFirstCredsOrErr: + } +} + +// identityToken is an implementation of [stscreds.IdentityTokenRetriever] for returning a static token. +type identityToken string + +// GetIdentityToken returns the token configured. +func (j identityToken) GetIdentityToken() ([]byte, error) { + return []byte(j), nil +} + +func errorValue(v error) slog.Attr { + return slog.Any("error", v) +} diff --git a/lib/integrations/awsoidc/credentialscache_test.go b/lib/integrations/awsoidc/credentialscache_test.go new file mode 100644 index 0000000000000..cc997758f70be --- /dev/null +++ b/lib/integrations/awsoidc/credentialscache_test.go @@ -0,0 +1,226 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package awsoidc + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/google/uuid" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/entitlements" + "github.com/gravitational/teleport/lib/modules" +) + +type fakeSTSClient struct { + clock clockwork.Clock + err error + sync.Mutex +} + +func (f *fakeSTSClient) setError(err error) { + f.Lock() + f.err = err + f.Unlock() +} + +func (f *fakeSTSClient) getError() error { + f.Lock() + defer f.Unlock() + return f.err +} + +func (f *fakeSTSClient) AssumeRoleWithWebIdentity(ctx context.Context, params *sts.AssumeRoleWithWebIdentityInput, optFns ...func(*sts.Options)) (*sts.AssumeRoleWithWebIdentityOutput, error) { + if err := f.getError(); err != nil { + return nil, err + } + + expiration := f.clock.Now().Add(time.Second * time.Duration(*params.DurationSeconds)) + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &ststypes.Credentials{ + Expiration: &expiration, + // These are example values taken from https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + SessionToken: aws.String("AQoDYXdzEE0a8ANXXXXXXXXNO1ewxE5TijQyp+IEXAMPLE"), + SecretAccessKey: aws.String("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY"), + AccessKeyId: aws.String("ASgeIAIOSFODNN7EXAMPLE"), + }, + }, nil +} + +func TestCredentialsCache(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + modules.SetTestModules(t, &modules.TestModules{ + TestFeatures: modules.Features{ + Cloud: true, + Entitlements: map[entitlements.EntitlementKind]modules.EntitlementInfo{ + entitlements.ExternalAuditStorage: {Enabled: true}, + }, + }, + }) + + // GIVEN a configured and running credential cache... + clock := clockwork.NewFakeClock() + stsClient := &fakeSTSClient{ + clock: clock, + } + cacheUnderTest, err := NewCredentialsCache(CredentialsCacheOptions{ + STSClient: stsClient, + Integration: "test", + Clock: clock, + }) + require.NoError(t, err) + require.NotNil(t, cacheUnderTest) + go cacheUnderTest.Run(ctx) + + advanceClock := func(d time.Duration) { + // Wait for the run loop to actually wait on the clock ticker before advancing. If we advance before + // the loop waits on the ticker, it may never tick. + clock.BlockUntil(1) + clock.Advance(d) + } + + // Set the GenerateOIDCTokenFn to a dumb faked function. + cacheUnderTest.SetGenerateOIDCTokenFn( + func(ctx context.Context, integration string) (string, error) { + return uuid.NewString(), nil + }) + + checkRetrieveCredentials := func(t require.TestingT, expectErr error) { + _, err := cacheUnderTest.Retrieve(ctx) + assert.ErrorIs(t, err, expectErr) + } + + checkRetrieveCredentialsWithExpiry := func(t require.TestingT, expectExpiry time.Time) { + creds, err := cacheUnderTest.Retrieve(ctx) + assert.NoError(t, err) + if err == nil { + assert.WithinDuration(t, expectExpiry, creds.Expires, time.Minute) + } + } + + const ( + // Using a longer wait time to avoid test flakes observed with 1s wait. + waitFor = 10 * time.Second + // We're using a short sleep (1ms) to allow the refresh loop goroutine to get scheduled. + // This keeps the test fast under normal conditions. If there's CPU starvation in CI, + // neither the test goroutine nor the refresh loop are likely getting scheduled often, + // so this shouldn't result in a busy loop. + tick = 1 * time.Millisecond + ) + + t.Run("Retrieve", func(t *testing.T) { + // Assert that credentials can be retrieved when everything is happy. + // EventuallyWithT is necessary to allow credentialsCache.run to be + // scheduled after SetGenerateOIDCTokenFn above. + initialCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + }, waitFor, tick) + }) + + t.Run("CachedCredsArePreservedOnError", func(t *testing.T) { + initialCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + }, waitFor, tick) + + // Assert that the good cached credentials are still used even if sts starts + // returning errors. + stsError := errors.New("test error") + stsClient.setError(stsError) + // Test immediately + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute before first refresh attempt + advanceClock(TokenLifetime - refreshBeforeExpirationPeriod - time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute after first refresh attempt + advanceClock(2 * time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + // Advance to 1 minute before credential expiry + advanceClock(refreshBeforeExpirationPeriod - 2*time.Minute) + checkRetrieveCredentialsWithExpiry(t, initialCredentialExpiry) + + // Advance 1 minute past the credential expiry and make sure we get the + // expected error. + advanceClock(2 * time.Minute) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentials(t, stsError) + }, waitFor, tick) + + // Fix STS and make sure we stop getting errors within refreshCheckInterval + stsClient.setError(nil) + advanceClock(refreshCheckInterval) + newCredentialExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + checkRetrieveCredentialsWithExpiry(t, newCredentialExpiry) + }, waitFor, tick) + }) + + t.Run("WindowedErrors", func(t *testing.T) { + // Test a scenario where STS is returning errors in two different 10-minute windows: the first surrounding + // the expected cert refresh time, and the second surrounding the cert expiry time. + // In this case the credentials cache should refresh the certs somewhere between those two outages, and + // clients should never see an error retrieving credentials. + newCredentialExpiry := clock.Now().Add(TokenLifetime) + expectedRefreshTime := newCredentialExpiry.Add(-refreshBeforeExpirationPeriod) + credentialsUpdated := false + done := newCredentialExpiry.Add(10 * time.Minute) + stsError := errors.New("test error") + for clock.Now().Before(done) { + if clock.Now().Sub(expectedRefreshTime).Abs() < 5*time.Minute || + clock.Now().Sub(newCredentialExpiry).Abs() < 5*time.Minute { + // Within one of the 10-minute outage windows, make the STS client return errors. + stsClient.setError(stsError) + advanceClock(time.Minute) + } else { + // Not within an outage window, STS client should not return errors. + stsClient.setError(nil) + advanceClock(time.Minute) + + if !credentialsUpdated && clock.Now().After(expectedRefreshTime) { + // This is after the expected refresh time and not within an outage window, for the test to + // not be flaky we need to wait for the cache run loop to get a chance to refresh the + // credentials. + expectedExpiry := clock.Now().Add(TokenLifetime) + require.EventuallyWithT(t, func(t *assert.CollectT) { + creds, err := cacheUnderTest.Retrieve(ctx) + assert.NoError(t, err) + assert.WithinDuration(t, expectedExpiry, creds.Expires, 2*time.Minute) + }, waitFor, tick) + credentialsUpdated = true + } + } + + // Assert that there is never an error getting credentials. + checkRetrieveCredentials(t, nil) + } + }) +} diff --git a/lib/integrations/externalauditstorage/configurator.go b/lib/integrations/externalauditstorage/configurator.go index 050298f9e13e0..66cea204a57cc 100644 --- a/lib/integrations/externalauditstorage/configurator.go +++ b/lib/integrations/externalauditstorage/configurator.go @@ -20,23 +20,21 @@ package externalauditstorage import ( "context" - "errors" - "sync" "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/externalauditstorage" "github.com/gravitational/teleport/entitlements" + "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" ) @@ -86,7 +84,7 @@ type Configurator struct { spec *externalauditstorage.ExternalAuditStorageSpec isUsed bool - credentialsCache *credentialsCache + credentialsCache *awsoidc.CredentialsCache } // Options holds options for the Configurator. @@ -202,7 +200,10 @@ func newConfigurator(ctx context.Context, spec *externalauditstorage.ExternalAud "ExternalAuditStorage: configured integration %q does not appear to be an AWS OIDC integration", oidcIntegrationName) } - awsRoleARN := awsOIDCSpec.RoleARN + awsRoleARN, err := arn.Parse(awsOIDCSpec.RoleARN) + if err != nil { + return nil, trace.Wrap(err, "AWS role is not a valid ARN") + } options := &Options{} for _, optFn := range optFns { @@ -212,11 +213,16 @@ func newConfigurator(ctx context.Context, spec *externalauditstorage.ExternalAud return nil, trace.Wrap(err) } - credentialsCache, err := newCredentialsCache(oidcIntegrationName, awsRoleARN, options) + credentialsCache, err := awsoidc.NewCredentialsCache(awsoidc.CredentialsCacheOptions{ + Integration: oidcIntegrationName, + RoleARN: awsRoleARN, + STSClient: options.stsClient, + Clock: options.clock, + }) if err != nil { return nil, trace.Wrap(err) } - go credentialsCache.run(ctx) + go credentialsCache.Run(ctx) // Draft configurator does not need to count errors or create cluster // alerts. @@ -245,13 +251,9 @@ func (c *Configurator) GetSpec() *externalauditstorage.ExternalAuditStorageSpec return c.spec } -// GenerateOIDCTokenFn is a function that should return a valid, signed JWT for -// authenticating to AWS via OIDC. -type GenerateOIDCTokenFn func(ctx context.Context, integration string) (string, error) - // SetGenerateOIDCTokenFn sets the source of OIDC tokens for this Configurator. -func (c *Configurator) SetGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { - c.credentialsCache.setGenerateOIDCTokenFn(fn) +func (c *Configurator) SetGenerateOIDCTokenFn(fn awsoidc.GenerateOIDCTokenFn) { + c.credentialsCache.SetGenerateOIDCTokenFn(fn) } // CredentialsProvider returns an aws.CredentialsProvider that can be used to @@ -274,185 +276,13 @@ func (p *Configurator) CredentialsProviderSDKV1() credentials.ProviderWithContex // credential providers won't return errors simply due to the cache not being // ready yet. func (p *Configurator) WaitForFirstCredentials(ctx context.Context) { - p.credentialsCache.waitForFirstCredsOrErr(ctx) -} - -// credentialsCache is used to store and refresh AWS credentials used with -// AWS OIDC integration. -// -// Credentials are valid for 1h, but they cannot be refreshed if Proxy is down, -// so we attempt to refresh the credentials early and retry on failure. -// -// credentialsCache is a dependency to both the s3 session uploader and the -// athena audit logger. They are both initialized before auth. However AWS -// credentials using OIDC integration can be obtained only after auth is -// initialized. That's why generateOIDCTokenFn is injected dynamically after -// auth is initialized. Before initialization, credentialsCache will return -// an error on any Retrieve call. -type credentialsCache struct { - log *logrus.Entry - - roleARN string - integration string - - // generateOIDCTokenFn is dynamically set after auth is initialized. - generateOIDCTokenFn GenerateOIDCTokenFn - - // initialized communicates (via closing channel) that generateOIDCTokenFn is set. - initialized chan struct{} - closeInitialized func() - - // gotFirstCredsOrErr communicates (via closing channel) that the first - // credsOrErr has been set. - gotFirstCredsOrErr chan struct{} - closeGotFirstCredsOrErr func() - - credsOrErr credsOrErr - credsOrErrMu sync.RWMutex - - stsClient stscreds.AssumeRoleWithWebIdentityAPIClient - clock clockwork.Clock -} - -type credsOrErr struct { - creds aws.Credentials - err error -} - -func newCredentialsCache(integration, roleARN string, options *Options) (*credentialsCache, error) { - initialized := make(chan struct{}) - gotFirstCredsOrErr := make(chan struct{}) - return &credentialsCache{ - roleARN: roleARN, - integration: integration, - log: logrus.WithField(teleport.ComponentKey, "ExternalAuditStorage.CredentialsCache"), - initialized: initialized, - closeInitialized: sync.OnceFunc(func() { close(initialized) }), - gotFirstCredsOrErr: gotFirstCredsOrErr, - closeGotFirstCredsOrErr: sync.OnceFunc(func() { close(gotFirstCredsOrErr) }), - credsOrErr: credsOrErr{ - err: errors.New("ExternalAuditStorage: credential cache not yet initialized"), - }, - clock: options.clock, - stsClient: options.stsClient, - }, nil -} - -func (cc *credentialsCache) setGenerateOIDCTokenFn(fn GenerateOIDCTokenFn) { - cc.generateOIDCTokenFn = fn - cc.closeInitialized() -} - -// Retrieve implements [aws.CredentialsProvider] and returns the latest cached -// credentials, or an error if no credentials have been generated yet or the -// last generated credentials have expired. -func (cc *credentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) { - cc.credsOrErrMu.RLock() - defer cc.credsOrErrMu.RUnlock() - return cc.credsOrErr.creds, cc.credsOrErr.err -} - -func (cc *credentialsCache) run(ctx context.Context) { - // Wait for initialized signal before running loop. - select { - case <-cc.initialized: - case <-ctx.Done(): - cc.log.Debug("Context canceled before initialized.") - return - } - - cc.refreshIfNeeded(ctx) - - ticker := cc.clock.NewTicker(refreshCheckInterval) - defer ticker.Stop() - for { - select { - case <-ticker.Chan(): - cc.refreshIfNeeded(ctx) - case <-ctx.Done(): - cc.log.Debugf("Context canceled, stopping refresh loop.") - return - } - } -} - -func (cc *credentialsCache) refreshIfNeeded(ctx context.Context) { - credsFromCache, err := cc.Retrieve(ctx) - if err == nil && - credsFromCache.HasKeys() && - cc.clock.Now().Add(refreshBeforeExpirationPeriod).Before(credsFromCache.Expires) { - // No need to refresh, credentials in cache are still valid for longer - // than refreshBeforeExpirationPeriod - return - } - cc.log.Debugf("Refreshing credentials.") - - creds, err := cc.refresh(ctx) - if err != nil { - cc.log.Warnf("Failed to retrieve new credentials: %v", err) - // If we were not able to refresh, check if existing credentials in cache are still valid. - // If yes, just log debug, it will be retried on next interval check. - if credsFromCache.HasKeys() && cc.clock.Now().Before(credsFromCache.Expires) { - cc.log.Debugf("Using existing credentials expiring in %s.", credsFromCache.Expires.Sub(cc.clock.Now()).Round(time.Second).String()) - return - } - // If existing creds are expired, update cached error. - cc.setCredsOrErr(credsOrErr{err: trace.Wrap(err)}) - return - } - // Refresh went well, update cached creds. - cc.setCredsOrErr(credsOrErr{creds: creds}) - cc.log.Debugf("Successfully refreshed credentials, new expiry at %v", creds.Expires) -} - -func (cc *credentialsCache) setCredsOrErr(coe credsOrErr) { - cc.credsOrErrMu.Lock() - defer cc.credsOrErrMu.Unlock() - cc.credsOrErr = coe - cc.closeGotFirstCredsOrErr() -} - -func (cc *credentialsCache) refresh(ctx context.Context) (aws.Credentials, error) { - oidcToken, err := cc.generateOIDCTokenFn(ctx, cc.integration) - if err != nil { - return aws.Credentials{}, trace.Wrap(err) - } - - roleProvider := stscreds.NewWebIdentityRoleProvider( - cc.stsClient, - cc.roleARN, - identityToken(oidcToken), - func(wiro *stscreds.WebIdentityRoleOptions) { - wiro.Duration = TokenLifetime - }, - ) - - ctx, cancel := context.WithTimeout(ctx, retrieveTimeout) - defer cancel() - - creds, err := roleProvider.Retrieve(ctx) - return creds, trace.Wrap(err) -} - -func (cc *credentialsCache) waitForFirstCredsOrErr(ctx context.Context) { - select { - case <-ctx.Done(): - case <-cc.gotFirstCredsOrErr: - } -} - -// identityToken is an implementation of [stscreds.IdentityTokenRetriever] for returning a static token. -type identityToken string - -// GetIdentityToken returns the token configured. -func (j identityToken) GetIdentityToken() ([]byte, error) { - return []byte(j), nil + p.credentialsCache.WaitForFirstCredsOrErr(ctx) } // v1Adapter wraps the credentialsCache to implement // [credentials.ProviderWithContext] used by aws-sdk-go (v1). type v1Adapter struct { - cc *credentialsCache + cc *awsoidc.CredentialsCache } var _ credentials.ProviderWithContext = (*v1Adapter)(nil) diff --git a/lib/integrations/externalauditstorage/configurator_test.go b/lib/integrations/externalauditstorage/configurator_test.go index abb1ce1425b9e..ba86e5f8e0c27 100644 --- a/lib/integrations/externalauditstorage/configurator_test.go +++ b/lib/integrations/externalauditstorage/configurator_test.go @@ -47,7 +47,7 @@ func testOIDCIntegration(t *testing.T) *types.IntegrationV1 { oidcIntegration, err := types.NewIntegrationAWSOIDC( types.Metadata{Name: "aws-integration-1"}, &types.AWSOIDCIntegrationSpecV1{ - RoleARN: "role1", + RoleARN: "arn:aws:iam::account:role/role1", }, ) require.NoError(t, err) diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 8b029bdab2ca8..ec596200d1edc 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -503,7 +503,7 @@ func TestAthenaAuditLogSetup(t *testing.T) { oidcIntegration, err := types.NewIntegrationAWSOIDC( types.Metadata{Name: "aws-integration-1"}, &types.AWSOIDCIntegrationSpecV1{ - RoleARN: "role1", + RoleARN: "arn:aws:iam::account:role/role1", }, ) require.NoError(t, err)