From 41024b84552ae29094d00e5cc25cf635776b7623 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Sat, 21 Sep 2024 15:54:49 -0700 Subject: [PATCH] hard-delete users (#4715) --- .../versions/cd164e83824f_users_and_tokens.py | 5 --- src/phoenix/db/models.py | 1 - src/phoenix/server/api/dataloaders/users.py | 6 ++-- .../server/api/mutations/api_key_mutations.py | 6 ++-- .../server/api/mutations/user_mutations.py | 31 ++++++------------- src/phoenix/server/api/queries.py | 15 ++------- src/phoenix/server/api/routers/auth.py | 15 +++------ 7 files changed, 21 insertions(+), 58 deletions(-) diff --git a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py index b61d98271c..60f4f03414 100644 --- a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +++ b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py @@ -61,11 +61,6 @@ def upgrade() -> None: server_default=sa.func.now(), onupdate=sa.func.now(), ), - sa.Column( - "deleted_at", - sa.TIMESTAMP(timezone=True), - nullable=True, - ), sa.CheckConstraint( "(password_hash IS NULL) = (password_salt IS NULL)", name="password_hash_and_salt", diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index bdf12cd4b3..278a5d6570 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -654,7 +654,6 @@ class User(Base): updated_at: Mapped[datetime] = mapped_column( UtcTimeStamp, server_default=func.now(), onupdate=func.now() ) - deleted_at: Mapped[Optional[datetime]] = mapped_column(UtcTimeStamp) password_reset_token: Mapped[Optional["PasswordResetToken"]] = relationship( "PasswordResetToken", back_populates="user", diff --git a/src/phoenix/server/api/dataloaders/users.py b/src/phoenix/server/api/dataloaders/users.py index 86a36c7ea9..7bbe34b69c 100644 --- a/src/phoenix/server/api/dataloaders/users.py +++ b/src/phoenix/server/api/dataloaders/users.py @@ -1,7 +1,7 @@ from collections import defaultdict from typing import DefaultDict, List, Optional -from sqlalchemy import and_, select +from sqlalchemy import select from strawberry.dataloader import DataLoader from typing_extensions import TypeAlias @@ -25,9 +25,7 @@ async def _load_fn(self, keys: List[Key]) -> List[Result]: users_by_id: DefaultDict[Key, Result] = defaultdict(None) async with self._db() as session: data = await session.stream_scalars( - select(models.User).where( - and_(models.User.id.in_(user_ids), models.User.deleted_at.is_(None)) - ) + select(models.User).where(models.User.id.in_(user_ids)) ) async for user in data: users_by_id[user.id] = user diff --git a/src/phoenix/server/api/mutations/api_key_mutations.py b/src/phoenix/server/api/mutations/api_key_mutations.py index ee4c5f165d..d7877c9418 100644 --- a/src/phoenix/server/api/mutations/api_key_mutations.py +++ b/src/phoenix/server/api/mutations/api_key_mutations.py @@ -2,7 +2,7 @@ from typing import Optional import strawberry -from sqlalchemy import and_, select +from sqlalchemy import select from strawberry import UNSET from strawberry.relay import GlobalID from strawberry.types import Info @@ -71,9 +71,7 @@ async def create_system_api_key( system_user = await session.scalar( select(models.User) .join(models.UserRole) # Join User with UserRole - .where( - and_(models.UserRole.name == user_role.value, models.User.deleted_at.is_(None)) - ) # Filter where role is SYSTEM + .where(models.UserRole.name == user_role.value) # Filter where role is SYSTEM .order_by(models.User.id) .limit(1) ) diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 0af7e8e9b0..874416b721 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -4,7 +4,7 @@ from typing import List, Literal, Optional, Tuple import strawberry -from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select, update +from sqlalchemy import Boolean, Select, and_, case, cast, delete, distinct, func, select from sqlalchemy.orm import joinedload from sqlean.dbapi2 import IntegrityError # type: ignore[import-untyped] from strawberry import UNSET @@ -258,7 +258,6 @@ async def delete_users( .where( and_( models.User.id.in_(user_ids), - models.User.deleted_at.is_(None), models.User.user_role_id != system_user_role_id, ) ) @@ -271,40 +270,30 @@ async def delete_users( password_reset_token_ids = [ PasswordResetTokenId(id_) async for id_ in await session.stream_scalars( - delete(models.PasswordResetToken) - .where(models.PasswordResetToken.user_id.in_(user_ids)) - .returning(models.PasswordResetToken.id) + select(models.PasswordResetToken.id).where( + models.PasswordResetToken.user_id.in_(user_ids) + ) ) ] access_token_ids = [ AccessTokenId(id_) async for id_ in await session.stream_scalars( - delete(models.AccessToken) - .where(models.AccessToken.user_id.in_(user_ids)) - .returning(models.AccessToken.id) + select(models.AccessToken.id).where(models.AccessToken.user_id.in_(user_ids)) ) ] refresh_token_ids = [ RefreshTokenId(id_) async for id_ in await session.stream_scalars( - delete(models.RefreshToken) - .where(models.RefreshToken.user_id.in_(user_ids)) - .returning(models.RefreshToken.id) + select(models.RefreshToken.id).where(models.RefreshToken.user_id.in_(user_ids)) ) ] api_key_ids = [ ApiKeyId(id_) async for id_ in await session.stream_scalars( - delete(models.ApiKey) - .where(models.ApiKey.user_id.in_(user_ids)) - .returning(models.ApiKey.id) + select(models.ApiKey.id).where(models.ApiKey.user_id.in_(user_ids)) ) ] - await session.execute( - update(models.User) - .where(models.User.id.in_(user_ids)) - .values(deleted_at=func.now()) - ) + await session.execute(delete(models.User).where(models.User.id.in_(user_ids))) await token_store.revoke( *password_reset_token_ids, *access_token_ids, @@ -319,9 +308,7 @@ def _select_role_id_by_name(role_name: str) -> Select[Tuple[int]]: def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]: return ( - select(models.User) - .where(and_(models.User.id == user_id, models.User.deleted_at.is_(None))) - .options(joinedload(models.User.role)) + select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) ) diff --git a/src/phoenix/server/api/queries.py b/src/phoenix/server/api/queries.py index 73efcca0d7..0c52856db8 100644 --- a/src/phoenix/server/api/queries.py +++ b/src/phoenix/server/api/queries.py @@ -96,12 +96,7 @@ async def users( stmt = ( select(models.User) .join(models.UserRole) - .where( - and_( - models.UserRole.name != enums.UserRole.SYSTEM.value, - models.User.deleted_at.is_(None), - ) - ) + .where(models.UserRole.name != enums.UserRole.SYSTEM.value) .order_by(models.User.email) .options(joinedload(models.User.role)) ) @@ -477,9 +472,7 @@ async def node(self, id: GlobalID, info: Info[Context, None]) -> Node: async with info.context.db() as session: if not ( user := await session.scalar( - select(models.User).where( - and_(models.User.id == node_id, models.User.deleted_at.is_(None)) - ) + select(models.User).where(models.User.id == node_id) ) ): raise NotFound(f"Unknown user: {id}") @@ -499,9 +492,7 @@ async def viewer(self, info: Info[Context, None]) -> Optional[User]: if ( user := await session.scalar( select(models.User) - .where( - and_(models.User.id == int(user.identity), models.User.deleted_at.is_(None)) - ) + .where(models.User.id == int(user.identity)) .options(joinedload(models.User.role)) ) ) is None: diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index efad5f40b8..83f7ba0438 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -2,10 +2,9 @@ import secrets from datetime import datetime, timedelta, timezone from functools import partial -from typing import Tuple from fastapi import APIRouter, Depends, HTTPException, Request, Response -from sqlalchemy import Select, select +from sqlalchemy import select from sqlalchemy.orm import joinedload from starlette.status import ( HTTP_204_NO_CONTENT, @@ -80,7 +79,7 @@ async def login(request: Request) -> Response: async with request.app.state.db() as session: user = await session.scalar( - _select_active_user().filter_by(email=email).options(joinedload(models.User.role)) + select(models.User).filter_by(email=email).options(joinedload(models.User.role)) ) if ( user is None @@ -160,7 +159,7 @@ async def refresh_tokens(request: Request) -> Response: async with request.app.state.db() as session: if ( user := await session.scalar( - _select_active_user().filter_by(id=user_id).options(joinedload(models.User.role)) + select(models.User).filter_by(id=user_id).options(joinedload(models.User.role)) ) ) is None: raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found") @@ -191,7 +190,7 @@ async def initiate_password_reset(request: Request) -> Response: assert isinstance(token_expiry := request.app.state.password_reset_token_expiry, timedelta) async with request.app.state.db() as session: user = await session.scalar( - _select_active_user() + select(models.User) .filter_by(email=email) .options( joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id) @@ -228,7 +227,7 @@ async def reset_password(request: Request) -> Response: raise INVALID_TOKEN assert (user_id := claims.subject) async with request.app.state.db() as session: - user = await session.scalar(_select_active_user().filter_by(id=int(user_id))) + user = await session.scalar(select(models.User).filter_by(id=int(user_id))) if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: # Withold privileged information return Response(status_code=HTTP_204_NO_CONTENT) @@ -249,10 +248,6 @@ async def reset_password(request: Request) -> Response: return response -def _select_active_user() -> Select[Tuple[models.User]]: - return select(models.User).where(models.User.deleted_at.is_(None)) - - LOGIN_FAILED_MESSAGE = "Invalid email and/or password" MISSING_EMAIL = HTTPException(