diff --git a/src/interface/web/app/components/loginPrompt/loginPrompt.tsx b/src/interface/web/app/components/loginPrompt/loginPrompt.tsx index 0aa150ea5..b372a6ef0 100644 --- a/src/interface/web/app/components/loginPrompt/loginPrompt.tsx +++ b/src/interface/web/app/components/loginPrompt/loginPrompt.tsx @@ -36,6 +36,8 @@ export interface LoginPromptProps { const fetcher = (url: string) => fetch(url).then((res) => res.json()); +const ALLOWED_OTP_ATTEMPTS = 5; + interface Provider { client_id: string; redirect_uri: string; @@ -230,10 +232,16 @@ function EmailSignInContext({ }) { const [otp, setOTP] = useState(""); const [otpError, setOTPError] = useState(""); + const [numFailures, setNumFailures] = useState(0); function checkOTPAndRedirect() { const verifyUrl = `/auth/magic?code=${otp}&email=${email}`; + if (numFailures >= ALLOWED_OTP_ATTEMPTS) { + setOTPError("Too many failed attempts. Please try again tomorrow."); + return; + } + fetch(verifyUrl, { method: "GET", headers: { @@ -246,8 +254,16 @@ function EmailSignInContext({ if (res.redirected) { window.location.href = res.url; } + } else if (res.status === 401) { + setOTPError("Invalid OTP."); + setNumFailures(numFailures + 1); + if (numFailures + 1 >= ALLOWED_OTP_ATTEMPTS) { + setOTPError("Too many failed attempts. Please try again tomorrow."); + } + } else if (res.status === 429) { + setOTPError("Too many failed attempts. Please try again tomorrow."); + setNumFailures(ALLOWED_OTP_ATTEMPTS); } else { - setOTPError("Invalid OTP"); throw new Error("Failed to verify OTP"); } }) @@ -309,6 +325,7 @@ function EmailSignInContext({ maxLength={6} value={otp || ""} onChange={setOTP} + disabled={numFailures >= ALLOWED_OTP_ATTEMPTS} onComplete={() => setTimeout(() => { checkOTPAndRedirect(); @@ -324,7 +341,11 @@ function EmailSignInContext({ -
{otpError}
+ {otpError && ( +
+ {otpError} {ALLOWED_OTP_ATTEMPTS - numFailures} remaining attempts. +
+ )} )} diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 82f3cda41..28f1d2d92 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -437,10 +437,14 @@ def is_user_subscribed(user: KhojUser) -> bool: return subscribed -async def get_user_by_email(email: str) -> KhojUser: +async def aget_user_by_email(email: str) -> KhojUser: return await KhojUser.objects.filter(email=email).afirst() +def get_user_by_email(email: str) -> KhojUser: + return KhojUser.objects.filter(email=email).first() + + async def aget_user_by_uuid(uuid: str) -> KhojUser: return await KhojUser.objects.filter(uuid=uuid).afirst() diff --git a/src/khoj/interface/email/magic_link.html b/src/khoj/interface/email/magic_link.html index 72a656792..148580230 100644 --- a/src/khoj/interface/email/magic_link.html +++ b/src/khoj/interface/email/magic_link.html @@ -16,12 +16,12 @@ Khoj Logo -

Hi!

-

Use this code (valid for 5 minutes) to login to Khoj:

{{ code }}

+

It will be valid for 5 minutes.

+

Alternatively, Click here to sign in on this browser.

diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py index a01bf6fbd..b509335f0 100644 --- a/src/khoj/routers/auth.py +++ b/src/khoj/routers/auth.py @@ -5,7 +5,7 @@ from typing import Optional import requests -from fastapi import APIRouter +from fastapi import APIRouter, Depends from pydantic import BaseModel, EmailStr from starlette.authentication import requires from starlette.config import Config @@ -22,7 +22,11 @@ get_or_create_user, ) from khoj.routers.email import send_magic_link_email, send_welcome_email -from khoj.routers.helpers import get_next_url, update_telemetry_state +from khoj.routers.helpers import ( + EmailVerificationApiRateLimiter, + get_next_url, + update_telemetry_state, +) from khoj.utils import state logger = logging.getLogger(__name__) @@ -99,7 +103,14 @@ async def login_magic_link(request: Request, form: MagicLinkForm): @auth_router.get("/magic") -async def sign_in_with_magic_link(request: Request, code: str, email: str): +async def sign_in_with_magic_link( + request: Request, + code: str, + email: str, + rate_limiter=Depends( + EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_verification") + ), +): user = await aget_user_validated_by_email_verification_code(code, email) if user: id_info = { @@ -108,7 +119,7 @@ async def sign_in_with_magic_link(request: Request, code: str, email: str): request.session["user"] = dict(id_info) return RedirectResponse(url="/") - return RedirectResponse(request.app.url_path_for("login_page")) + return Response(status_code=401) @auth_router.post("/token") diff --git a/src/khoj/routers/email.py b/src/khoj/routers/email.py index 79061da0d..1f60a1ba8 100644 --- a/src/khoj/routers/email.py +++ b/src/khoj/routers/email.py @@ -47,7 +47,7 @@ async def send_magic_link_email(email, unique_id, host): { "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"), "to": email, - "subject": f"{unique_id} - Sign in to Khoj 🚀", + "subject": f"Your unique login to Khoj", "html": html_content, } ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index ecd1f1e40..af0e980dd 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -49,6 +49,7 @@ ais_user_subscribed, create_khoj_token, get_khoj_tokens, + get_user_by_email, get_user_name, get_user_notion_config, get_user_subscription_state, @@ -1363,6 +1364,49 @@ class FeedbackData(BaseModel): sentiment: str +class EmailVerificationApiRateLimiter: + def __init__(self, requests: int, window: int, slug: str): + self.requests = requests + self.window = window + self.slug = slug + + def __call__(self, request: Request): + # Rate limiting disabled if billing is disabled + if state.billing_enabled is False: + return + + # Extract the email query parameter + email = request.query_params.get("email") + + if email: + logger.info(f"Email query parameter: {email}") + + user: KhojUser = get_user_by_email(email) + + if not user: + raise HTTPException( + status_code=404, + detail="User not found.", + ) + + # Remove requests outside of the time window + cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window) + count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count() + + # Check if the user has exceeded the rate limit + if count_requests >= self.requests: + logger.info( + f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}." + ) + raise HTTPException( + status_code=429, + detail="Ran out of login attempts", + ) + + # Add the current request to the db + UserRequests.objects.create(user=user, slug=self.slug) + + class ApiUserRateLimiter: def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): self.requests = requests