diff --git a/apps/backend/src/auth/request.ts b/apps/backend/src/auth/request.ts index 86f8f7b81..fab118836 100644 --- a/apps/backend/src/auth/request.ts +++ b/apps/backend/src/auth/request.ts @@ -1,21 +1,25 @@ +import { invariant } from "@argos/util/invariant"; import * as authorization from "auth-header"; import type { Request } from "express"; import { Account, User } from "@/database/models/index.js"; +import { boom } from "@/web/util.js"; import { verifyJWT } from "./jwt.js"; -export class AuthError extends Error {} - const getTokenFromAuthHeader = (authHeader: string) => { - const auth = authorization.parse(authHeader); - if (auth.scheme !== "Bearer") { - throw new AuthError(`Invalid auth scheme: ${auth.scheme || "no scheme"}`); - } - if (typeof auth.token !== "string" || !auth.token) { - throw new AuthError("Invalid auth token"); + try { + const auth = authorization.parse(authHeader); + if (auth.scheme !== "Bearer") { + return null; + } + if (typeof auth.token !== "string" || !auth.token) { + return null; + } + return auth.token; + } catch { + return null; } - return auth.token; }; export type AuthPayload = { @@ -26,17 +30,15 @@ export type AuthPayload = { const getAuthPayloadFromToken = async (token: string): Promise => { const jwt = verifyJWT(token); if (!jwt) { - throw new AuthError("Invalid JWT"); + throw boom(401, "Invalid JWT"); } const account = await Account.query() .withGraphFetched("user") .findById(jwt.account.id); if (!account) { - throw new AuthError("Account not found"); - } - if (!account.user) { - throw new AuthError("Account has no user"); + throw boom(401, "Account not found"); } + invariant(account.user, "Account has no user"); return { account, user: account.user }; }; @@ -48,5 +50,8 @@ export async function getAuthPayloadFromRequest( return null; } const token = getTokenFromAuthHeader(authHeader); + if (!token) { + return null; + } return getAuthPayloadFromToken(token); } diff --git a/apps/backend/src/graphql/context.ts b/apps/backend/src/graphql/context.ts index 846d6e973..515321806 100644 --- a/apps/backend/src/graphql/context.ts +++ b/apps/backend/src/graphql/context.ts @@ -2,11 +2,8 @@ import type { BaseContext } from "@apollo/server"; import type { Request } from "express"; import { GraphQLError } from "graphql"; -import { - AuthError, - AuthPayload, - getAuthPayloadFromRequest, -} from "@/auth/request.js"; +import { AuthPayload, getAuthPayloadFromRequest } from "@/auth/request.js"; +import { HTTPError } from "@/web/util.js"; import { createLoaders } from "./loaders.js"; @@ -28,7 +25,7 @@ export async function getContext(request: Request): Promise { const auth = await getContextAuth(request); return { auth, loaders: createLoaders() }; } catch (error) { - if (error instanceof AuthError) { + if (error instanceof HTTPError && error.statusCode === 401) { throw new GraphQLError("User is not authenticated", { originalError: error, extensions: { diff --git a/apps/backend/src/web/middlewares/auth.ts b/apps/backend/src/web/middlewares/auth.ts index bb9eec02c..40ac8bc37 100644 --- a/apps/backend/src/web/middlewares/auth.ts +++ b/apps/backend/src/web/middlewares/auth.ts @@ -4,7 +4,7 @@ import type { RequestHandler } from "express"; import { AuthPayload, getAuthPayloadFromRequest } from "@/auth/request.js"; -import { asyncHandler } from "../util.js"; +import { asyncHandler, HTTPError } from "../util.js"; declare global { namespace Express { @@ -15,7 +15,13 @@ declare global { } export const auth: RequestHandler = asyncHandler(async (req, _res, next) => { - const account = await getAuthPayloadFromRequest(req); + const account = await getAuthPayloadFromRequest(req).catch((error) => { + if (error instanceof HTTPError && error.statusCode === 401) { + return null; + } + + throw error; + }); req.auth = account; next(); }); diff --git a/apps/backend/src/web/util.ts b/apps/backend/src/web/util.ts index 7e8ffcef9..9106cc927 100644 --- a/apps/backend/src/web/util.ts +++ b/apps/backend/src/web/util.ts @@ -11,7 +11,11 @@ import config from "@/config/index.js"; export const asyncHandler = (requestHandler: RequestHandler): RequestHandler => (req, res, next) => { - Promise.resolve(requestHandler(req, res, next)).catch(next); + try { + Promise.resolve(requestHandler(req, res, next)).catch(next); + } catch (error) { + next(error); + } }; export const subdomain = @@ -49,7 +53,7 @@ type HttpErrorOptions = ErrorOptions & { /** * HTTPError is a subclass of Error that includes an HTTP status code. */ -class HTTPError extends Error { +export class HTTPError extends Error { public statusCode: number; public details: | {