Skip to content

Commit

Permalink
Add rate limiting to OTP login attempts and update email template for…
Browse files Browse the repository at this point in the history
… conciseness
  • Loading branch information
sabaimran committed Dec 16, 2024
1 parent b778335 commit ae9750e
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 10 deletions.
25 changes: 23 additions & 2 deletions src/interface/web/app/components/loginPrompt/loginPrompt.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ export interface LoginPromptProps {

const fetcher = (url: string) => fetch(url).then((res) => res.json());

const ALLOWED_OTP_ATTEMPTS = 5;

interface Provider {
client_id: string;
redirect_uri: string;
Expand Down Expand Up @@ -230,10 +232,16 @@ function EmailSignInContext({
}) {
const [otp, setOTP] = useState("");
const [otpError, setOTPError] = useState("");
const [numFailures, setNumFailures] = useState(0);

function checkOTPAndRedirect() {
const verifyUrl = `/auth/magic?code=${otp}&email=${email}`;

if (numFailures >= ALLOWED_OTP_ATTEMPTS) {
setOTPError("Too many failed attempts. Please try again tomorrow.");
return;
}

fetch(verifyUrl, {
method: "GET",
headers: {
Expand All @@ -246,8 +254,16 @@ function EmailSignInContext({
if (res.redirected) {
window.location.href = res.url;
}
} else if (res.status === 401) {
setOTPError("Invalid OTP.");
setNumFailures(numFailures + 1);
if (numFailures + 1 >= ALLOWED_OTP_ATTEMPTS) {
setOTPError("Too many failed attempts. Please try again tomorrow.");
}
} else if (res.status === 429) {
setOTPError("Too many failed attempts. Please try again tomorrow.");
setNumFailures(ALLOWED_OTP_ATTEMPTS);
} else {
setOTPError("Invalid OTP");
throw new Error("Failed to verify OTP");
}
})
Expand Down Expand Up @@ -309,6 +325,7 @@ function EmailSignInContext({
maxLength={6}
value={otp || ""}
onChange={setOTP}
disabled={numFailures >= ALLOWED_OTP_ATTEMPTS}
onComplete={() =>
setTimeout(() => {
checkOTPAndRedirect();
Expand All @@ -324,7 +341,11 @@ function EmailSignInContext({
<InputOTPSlot index={5} />
</InputOTPGroup>
</InputOTP>
<div className="text-red-500 text-sm">{otpError}</div>
{otpError && (
<div className="text-red-500 text-sm">
{otpError} {ALLOWED_OTP_ATTEMPTS - numFailures} remaining attempts.
</div>
)}
</div>
)}

Expand Down
6 changes: 5 additions & 1 deletion src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,10 +437,14 @@ def is_user_subscribed(user: KhojUser) -> bool:
return subscribed


async def get_user_by_email(email: str) -> KhojUser:
async def aget_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst()


def get_user_by_email(email: str) -> KhojUser:
return KhojUser.objects.filter(email=email).first()


async def aget_user_by_uuid(uuid: str) -> KhojUser:
return await KhojUser.objects.filter(uuid=uuid).afirst()

Expand Down
4 changes: 2 additions & 2 deletions src/khoj/interface/email/magic_link.html
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
<img src="https://assets.khoj.dev/khoj_logo.png" alt="Khoj Logo" style="width: 120px;">
</a>

<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Hi!</p>

<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Use this code (valid for 5 minutes) to login to Khoj:</p>

<h1 style="font-size: 24px; color: #2c3e50; margin-bottom: 20px; text-align: center;">{{ code }}</h1>

<p style="font-size: 16px; color: #333; margin-bottom: 20px;">It will be valid for 5 minutes.</p>

<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Alternatively, <a href="{{ link }}" target="_blank"
style="color: #FFA07A; text-decoration: none; font-weight: bold;">Click here to sign in on this
browser.</a></p>
Expand Down
19 changes: 15 additions & 4 deletions src/khoj/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Optional

import requests
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from pydantic import BaseModel, EmailStr
from starlette.authentication import requires
from starlette.config import Config
Expand All @@ -22,7 +22,11 @@
get_or_create_user,
)
from khoj.routers.email import send_magic_link_email, send_welcome_email
from khoj.routers.helpers import get_next_url, update_telemetry_state
from khoj.routers.helpers import (
EmailVerificationApiRateLimiter,
get_next_url,
update_telemetry_state,
)
from khoj.utils import state

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,7 +103,14 @@ async def login_magic_link(request: Request, form: MagicLinkForm):


@auth_router.get("/magic")
async def sign_in_with_magic_link(request: Request, code: str, email: str):
async def sign_in_with_magic_link(
request: Request,
code: str,
email: str,
rate_limiter=Depends(
EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_verification")
),
):
user = await aget_user_validated_by_email_verification_code(code, email)
if user:
id_info = {
Expand All @@ -108,7 +119,7 @@ async def sign_in_with_magic_link(request: Request, code: str, email: str):

request.session["user"] = dict(id_info)
return RedirectResponse(url="/")
return RedirectResponse(request.app.url_path_for("login_page"))
return Response(status_code=401)


@auth_router.post("/token")
Expand Down
2 changes: 1 addition & 1 deletion src/khoj/routers/email.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def send_magic_link_email(email, unique_id, host):
{
"sender": os.environ.get("RESEND_EMAIL", "[email protected]"),
"to": email,
"subject": f"{unique_id} - Sign in to Khoj 🚀",
"subject": f"Your unique login to Khoj",
"html": html_content,
}
)
Expand Down
44 changes: 44 additions & 0 deletions src/khoj/routers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
ais_user_subscribed,
create_khoj_token,
get_khoj_tokens,
get_user_by_email,
get_user_name,
get_user_notion_config,
get_user_subscription_state,
Expand Down Expand Up @@ -1363,6 +1364,49 @@ class FeedbackData(BaseModel):
sentiment: str


class EmailVerificationApiRateLimiter:
def __init__(self, requests: int, window: int, slug: str):
self.requests = requests
self.window = window
self.slug = slug

def __call__(self, request: Request):
# Rate limiting disabled if billing is disabled
if state.billing_enabled is False:
return

# Extract the email query parameter
email = request.query_params.get("email")

if email:
logger.info(f"Email query parameter: {email}")

user: KhojUser = get_user_by_email(email)

if not user:
raise HTTPException(
status_code=404,
detail="User not found.",
)

# Remove requests outside of the time window
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()

# Check if the user has exceeded the rate limit
if count_requests >= self.requests:
logger.info(
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
)
raise HTTPException(
status_code=429,
detail="Ran out of login attempts",
)

# Add the current request to the db
UserRequests.objects.create(user=user, slug=self.slug)


class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests
Expand Down

0 comments on commit ae9750e

Please sign in to comment.