diff --git a/bun.lockb b/bun.lockb new file mode 100755 index 000000000..59585c93d Binary files /dev/null and b/bun.lockb differ diff --git a/packages/api/src/auth/hono.ts b/packages/api/src/auth/hono.ts deleted file mode 100644 index 03ef9b4f0..000000000 --- a/packages/api/src/auth/hono.ts +++ /dev/null @@ -1,22 +0,0 @@ -import { getAllowedOriginHost } from '.' -import type { Context as HonoContext, Next } from 'hono' -import { Bindings } from '../worker' -import { verifyRequestOrigin } from 'oslo/request' - -export const csrfMiddleware = async (c: HonoContext<{ Bindings: Bindings }>, next: Next) => { - // CSRF middleware - if (c.req.method === 'GET') { - return next() - } - const originHeader = c.req.header('origin') - const hostHeader = c.req.header('host') - const allowedOrigin = getAllowedOriginHost(c.env.APP_URL, c.req.raw) - if ( - !originHeader || - !hostHeader || - !verifyRequestOrigin(originHeader, [hostHeader, ...(allowedOrigin ? [allowedOrigin] : [])]) - ) { - return c.body(null, 403) - } - return next() -} diff --git a/packages/api/src/context.ts b/packages/api/src/context.ts index b58a54215..6d1bae6ce 100644 --- a/packages/api/src/context.ts +++ b/packages/api/src/context.ts @@ -5,9 +5,9 @@ import type { User } from './db/schema' import { Bindings } from './worker' import type { inferAsyncReturnType } from '@trpc/server' import type { Context as HonoContext, HonoRequest } from 'hono' -import type { Lucia } from 'lucia' +import { verifyRequestOrigin, type Lucia } from 'lucia' import { verifyToken } from './utils/crypto' -import { createAuth } from './auth' +import { createAuth, getAllowedOriginHost } from './auth' import { getCookie } from 'hono/cookie' export interface ApiContextProps { @@ -16,6 +16,7 @@ export interface ApiContextProps { auth: Lucia req?: HonoRequest c?: HonoContext + enableTokens: boolean setCookie: (value: string) => void db: DB env: Bindings @@ -63,6 +64,7 @@ export const createContext = async ( // const user = await getUser() const auth = createAuth(db, env.APP_URL) + const enableTokens = Boolean(context.req.header('x-enable-tokens')) async function getSession() { let user: User | undefined @@ -75,28 +77,44 @@ export const createContext = async ( if (!context.req) return res const cookieSessionId = getCookie(context, auth.sessionCookieName) - const bearerSessionId = - !cookieSessionId && - context.req.header('x-enable-tokens') && - context.req.header('authorization')?.split(' ')[1] + const bearerSessionId = enableTokens && context.req.header('authorization')?.split(' ')[1] if (!cookieSessionId && !bearerSessionId) return res - const authResult = await auth.validateSession(cookieSessionId || bearerSessionId || '') - if (cookieSessionId) { - if (authResult.session?.fresh) { - context.header('Set-Cookie', auth.createSessionCookie(authResult.session.id).serialize(), { - append: true, - }) - } - if (!session) { - context.header('Set-Cookie', auth.createBlankSessionCookie().serialize(), { - append: true, - }) + let authResult: Awaited> | undefined + if (cookieSessionId && !enableTokens) { + const originHeader = context.req.header('origin') + const hostHeader = context.req.header('host') + const allowedOrigin = getAllowedOriginHost(context.env.APP_URL, context.req.raw) + if ( + originHeader && + hostHeader && + verifyRequestOrigin(originHeader, [hostHeader, ...(allowedOrigin ? [allowedOrigin] : [])]) + ) { + authResult = await auth.validateSession(cookieSessionId) + if (authResult.session?.fresh) { + context.header( + 'Set-Cookie', + auth.createSessionCookie(authResult.session.id).serialize(), + { + append: true, + } + ) + } + if (!authResult?.session) { + context.header('Set-Cookie', auth.createBlankSessionCookie().serialize(), { + append: true, + }) + } + } else { + console.log('CSRF failed', { cookieSessionId, originHeader, hostHeader, allowedOrigin }) } } - res.session = authResult.session || undefined - res.user = authResult.user || undefined + if (bearerSessionId) { + authResult = await auth.validateSession(bearerSessionId) + } + res.session = authResult?.session || undefined + res.user = authResult?.user || undefined return res } @@ -108,6 +126,7 @@ export const createContext = async ( c: context, session, user, + enableTokens, setCookie: (value) => { resHeaders.append('set-cookie', value) }, diff --git a/packages/api/src/routes/user.ts b/packages/api/src/routes/user.ts index 6e6744937..304ec8ed9 100644 --- a/packages/api/src/routes/user.ts +++ b/packages/api/src/routes/user.ts @@ -186,7 +186,7 @@ const signInWithAppleIdTokenHandler = idTokenClaims: payload, }) const session = await createSession(ctx.auth, user.id) - if (ctx.setCookie) { + if (ctx.enableTokens && ctx.setCookie) { ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize()) } return { session } @@ -247,7 +247,7 @@ const signInWithEmailCodeHandler = console.log('calling update passing and invalidate sessions') await ctx.auth.invalidateUserSessions(res.session?.userId) const session = await createSession(ctx.auth, res.session?.userId) - if (ctx.setCookie) { + if (ctx.enableTokens && ctx.setCookie) { ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize()) } res.session = session @@ -342,7 +342,7 @@ export const userRouter = router({ email: input.email, }) const session = await createSession(ctx.auth, user.id) - if (ctx.setCookie) { + if (ctx.enableTokens && ctx.setCookie) { ctx.setCookie(ctx.auth.createSessionCookie(session.id).serialize()) } return { session } diff --git a/packages/api/src/worker.ts b/packages/api/src/worker.ts index 2f0d71031..8575d44cc 100644 --- a/packages/api/src/worker.ts +++ b/packages/api/src/worker.ts @@ -3,7 +3,6 @@ import { appRouter } from '@t4/api/src/router' import { cors } from 'hono/cors' import { createContext } from '@t4/api/src/context' import { trpcServer } from '@hono/trpc-server' -import { csrfMiddleware } from './auth/hono' export type Bindings = Env & { JWT_VERIFICATION_KEY: string @@ -42,8 +41,6 @@ const corsHandler = async (c: Context<{ Bindings: Bindings }>, next: Next) => { })(c, next) } -app.use('*', csrfMiddleware) - // Setup CORS for the frontend app.use('/trpc/*', corsHandler) diff --git a/packages/app/features/oauth/screen.tsx b/packages/app/features/oauth/screen.tsx index 98c972ff2..806e002ea 100644 --- a/packages/app/features/oauth/screen.tsx +++ b/packages/app/features/oauth/screen.tsx @@ -38,7 +38,6 @@ export interface OAuthSignInScreenProps { appleUser?: { email?: string | null } | null } - export const OAuthSignInScreen = ({ appleUser }: OAuthSignInScreenProps): React.ReactNode => { const sent = useRef(false) const { signIn } = useSignIn() @@ -83,9 +82,11 @@ export const OAuthSignInScreen = ({ appleUser }: OAuthSignInScreenProps): React. code, // undefined vs null is a result of passing via JSON with getServerSideProps // Maybe there's a superjson plugin or another way to handle it. - appleUser: appleUser ? { - email: appleUser.email || undefined, - } : undefined, + appleUser: appleUser + ? { + email: appleUser.email || undefined, + } + : undefined, }) }, [provider, redirectTo, state, code, sendApiRequestOnLoad, appleUser]) diff --git a/packages/app/utils/auth/index.ts b/packages/app/utils/auth/index.ts index 0ec14e608..1859b807f 100644 --- a/packages/app/utils/auth/index.ts +++ b/packages/app/utils/auth/index.ts @@ -181,18 +181,14 @@ export function isSignInWithOAuth(props: SignInProps): props is SignInWithOAuth } export function useSignIn() { - // TODO ^ maybe accept props for what to do after sign in? + // TODO ^ maybe accept props for what to do after sign in const mutation = trpc.user.signIn.useMutation() const utils = trpc.useUtils() const postLogin = (res: SignInResult) => { - // We _could_ call setSessionAtom here and store the session ID for native... - // but it's a bit simpler to centralize all of the session - // logic around the user.session trpc query (see useSessionContext() above). - // It causes a redundant db queries and API roundtrip though... - // It might be possible to update the trpc user.session cache manually - // so we can keep the logic centralized around the user.session query while - // avoiding the additional request. + if (!isWeb) { + storeSessionToken(res.session?.id) + } utils.user.invalidate() utils.auth.invalidate() } @@ -245,6 +241,9 @@ export function useSignUp() { const mutation = trpc.user.create.useMutation() const utils = trpc.useUtils() const postLogin = (res: SignInResult) => { + if (!isWeb) { + storeSessionToken(res.session?.id) + } utils.user.invalidate() utils.auth.invalidate() }