Skip to content

Commit

Permalink
hard-delete users (#4715)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored Sep 21, 2024
1 parent 946b9c5 commit 41024b8
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,6 @@ def upgrade() -> None:
server_default=sa.func.now(),
onupdate=sa.func.now(),
),
sa.Column(
"deleted_at",
sa.TIMESTAMP(timezone=True),
nullable=True,
),
sa.CheckConstraint(
"(password_hash IS NULL) = (password_salt IS NULL)",
name="password_hash_and_salt",
Expand Down
1 change: 0 additions & 1 deletion src/phoenix/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,6 @@ class User(Base):
updated_at: Mapped[datetime] = mapped_column(
UtcTimeStamp, server_default=func.now(), onupdate=func.now()
)
deleted_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp)
password_reset_token: Mapped[Optional["PasswordResetToken"]] = relationship(
"PasswordResetToken",
back_populates="user",
Expand Down
6 changes: 2 additions & 4 deletions src/phoenix/server/api/dataloaders/users.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections import defaultdict
from typing import DefaultDict, List, Optional

from sqlalchemy import and_, select
from sqlalchemy import select
from strawberry.dataloader import DataLoader
from typing_extensions import TypeAlias

Expand All @@ -25,9 +25,7 @@ async def _load_fn(self, keys: List[Key]) -> List[Result]:
users_by_id: DefaultDict[Key, Result] = defaultdict(None)
async with self._db() as session:
data = await session.stream_scalars(
select(models.User).where(
and_(models.User.id.in_(user_ids), models.User.deleted_at.is_(None))
)
select(models.User).where(models.User.id.in_(user_ids))
)
async for user in data:
users_by_id[user.id] = user
Expand Down
6 changes: 2 additions & 4 deletions src/phoenix/server/api/mutations/api_key_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional

import strawberry
from sqlalchemy import and_, select
from sqlalchemy import select
from strawberry import UNSET
from strawberry.relay import GlobalID
from strawberry.types import Info
Expand Down Expand Up @@ -71,9 +71,7 @@ async def create_system_api_key(
system_user = await session.scalar(
select(models.User)
.join(models.UserRole) # Join User with UserRole
.where(
and_(models.UserRole.name == user_role.value, models.User.deleted_at.is_(None))
) # Filter where role is SYSTEM
.where(models.UserRole.name == user_role.value) # Filter where role is SYSTEM
.order_by(models.User.id)
.limit(1)
)
Expand Down
31 changes: 9 additions & 22 deletions src/phoenix/server/api/mutations/user_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import List, Literal, Optional, Tuple

import strawberry
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select, update
from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select
from sqlalchemy.orm import joinedload
from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped]
from strawberry import UNSET
Expand Down Expand Up @@ -258,7 +258,6 @@ async def delete_users(
.where(
and_(
models.User.id.in_(user_ids),
models.User.deleted_at.is_(None),
models.User.user_role_id != system_user_role_id,
)
)
Expand All @@ -271,40 +270,30 @@ async def delete_users(
password_reset_token_ids = [
PasswordResetTokenId(id_)
async for id_ in await session.stream_scalars(
delete(models.PasswordResetToken)
.where(models.PasswordResetToken.user_id.in_(user_ids))
.returning(models.PasswordResetToken.id)
select(models.PasswordResetToken.id).where(
models.PasswordResetToken.user_id.in_(user_ids)
)
)
]
access_token_ids = [
AccessTokenId(id_)
async for id_ in await session.stream_scalars(
delete(models.AccessToken)
.where(models.AccessToken.user_id.in_(user_ids))
.returning(models.AccessToken.id)
select(models.AccessToken.id).where(models.AccessToken.user_id.in_(user_ids))
)
]
refresh_token_ids = [
RefreshTokenId(id_)
async for id_ in await session.stream_scalars(
delete(models.RefreshToken)
.where(models.RefreshToken.user_id.in_(user_ids))
.returning(models.RefreshToken.id)
select(models.RefreshToken.id).where(models.RefreshToken.user_id.in_(user_ids))
)
]
api_key_ids = [
ApiKeyId(id_)
async for id_ in await session.stream_scalars(
delete(models.ApiKey)
.where(models.ApiKey.user_id.in_(user_ids))
.returning(models.ApiKey.id)
select(models.ApiKey.id).where(models.ApiKey.user_id.in_(user_ids))
)
]
await session.execute(
update(models.User)
.where(models.User.id.in_(user_ids))
.values(deleted_at=func.now())
)
await session.execute(delete(models.User).where(models.User.id.in_(user_ids)))
await token_store.revoke(
*password_reset_token_ids,
*access_token_ids,
Expand All @@ -319,9 +308,7 @@ def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]:

def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]:
return (
select(models.User)
.where(and_(models.User.id == user_id, models.User.deleted_at.is_(None)))
.options(joinedload(models.User.role))
select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role))
)


Expand Down
15 changes: 3 additions & 12 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,7 @@ async def users(
stmt = (
select(models.User)
.join(models.UserRole)
.where(
and_(
models.UserRole.name != enums.UserRole.SYSTEM.value,
models.User.deleted_at.is_(None),
)
)
.where(models.UserRole.name != enums.UserRole.SYSTEM.value)
.order_by(models.User.email)
.options(joinedload(models.User.role))
)
Expand Down Expand Up @@ -477,9 +472,7 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node:
async with info.context.db() as session:
if not (
user := await session.scalar(
select(models.User).where(
and_(models.User.id == node_id, models.User.deleted_at.is_(None))
)
select(models.User).where(models.User.id == node_id)
)
):
raise NotFound(f"Unknown user: {id}")
Expand All @@ -499,9 +492,7 @@ async def viewer(self, info: Info[Context, None]) -> Optional[User]:
if (
user := await session.scalar(
select(models.User)
.where(
and_(models.User.id == int(user.identity), models.User.deleted_at.is_(None))
)
.where(models.User.id == int(user.identity))
.options(joinedload(models.User.role))
)
) is None:
Expand Down
15 changes: 5 additions & 10 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@
import secrets
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import Tuple

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sqlalchemy import Select, select
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from starlette.status import (
HTTP_204_NO_CONTENT,
Expand Down Expand Up @@ -80,7 +79,7 @@ async def login(request: Request) -> Response:

async with request.app.state.db() as session:
user = await session.scalar(
_select_active_user().filter_by(email=email).options(joinedload(models.User.role))
select(models.User).filter_by(email=email).options(joinedload(models.User.role))
)
if (
user is None
Expand Down Expand Up @@ -160,7 +159,7 @@ async def refresh_tokens(request: Request) -> Response:
async with request.app.state.db() as session:
if (
user := await session.scalar(
_select_active_user().filter_by(id=user_id).options(joinedload(models.User.role))
select(models.User).filter_by(id=user_id).options(joinedload(models.User.role))
)
) is None:
raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found")
Expand Down Expand Up @@ -191,7 +190,7 @@ async def initiate_password_reset(request: Request) -> Response:
assert isinstance(token_expiry := request.app.state.password_reset_token_expiry, timedelta)
async with request.app.state.db() as session:
user = await session.scalar(
_select_active_user()
select(models.User)
.filter_by(email=email)
.options(
joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id)
Expand Down Expand Up @@ -228,7 +227,7 @@ async def reset_password(request: Request) -> Response:
raise INVALID_TOKEN
assert (user_id := claims.subject)
async with request.app.state.db() as session:
user = await session.scalar(_select_active_user().filter_by(id=int(user_id)))
user = await session.scalar(select(models.User).filter_by(id=int(user_id)))
if user is None or user.auth_method != enums.AuthMethod.LOCAL.value:
# Withold privileged information
return Response(status_code=HTTP_204_NO_CONTENT)
Expand All @@ -249,10 +248,6 @@ async def reset_password(request: Request) -> Response:
return response


def _select_active_user() -> Select[Tuple[models.User]]:
return select(models.User).where(models.User.deleted_at.is_(None))


LOGIN_FAILED_MESSAGE = "Invalid email and/or password"

MISSING_EMAIL = HTTPException(
Expand Down

0 comments on commit 41024b8

Please sign in to comment.