Skip to content

Commit

Permalink
update user info
Browse files Browse the repository at this point in the history
  • Loading branch information
FireFading committed Oct 16, 2023
1 parent 6cc2d07 commit fde6ed1
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 56 deletions.
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ SECRET_KEY="secret"
ALGORITHM="HS256"
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_MINUTES=600
JWT_REFRESH_SECRET_KEY="secret"
REFRESH_SECRET_KEY="secret"
58 changes: 20 additions & 38 deletions app/controllers/users.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,37 @@
from datetime import datetime, timedelta, timezone
from typing import Annotated

from app.config import jwt_settings
from app.database import with_async_session
from app.models.users import User as UserModel
from app.schemas.tokens import TokenData
from app.schemas.users import UserCreate as UserCreateSchema
from app.schemas.users import UserUpdate as UserUpdateSchema
from app.services.users import UsersService, users_service
from fastapi import Depends, HTTPException, status
from fastapi import HTTPException, status
from jose import JWTError, jwt
from sqlalchemy.ext.asyncio import AsyncSession


class UsersController:
def __init__(self, users_service: UsersService) -> None:
self.users_service = users_service

@with_async_session
async def register(
self,
user_schema: UserCreateSchema,
session: AsyncSession | None = None,
) -> UserModel | HTTPException:
email = user_schema.email
if await self.users_service.get_user(session=session, email=email):
if await self.users_service.get_user(email=email):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="User with this email already exists",
)
user = UserModel(**user_schema.model_dump())
return await self.users_service.create(user=user, session=session)
return await self.users_service.create(user=user)

@with_async_session
async def authenticate_user(
self,
username: str,
password: str,
session: AsyncSession | None,
) -> UserModel | HTTPException:
user = await self.users_service.authenticate_user(
username=username, password=password, session=session
)
user = await self.users_service.authenticate_user(username=username, password=password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand All @@ -49,11 +40,14 @@ async def authenticate_user(
)
return user

@with_async_session
async def update_user_info(self, user: UserModel, update_user_schema: UserUpdateSchema) -> UserModel:
user.full_name = update_user_schema.full_name
user.email = update_user_schema.email
return await self.users_service.update_info(user=user)

async def verify_token(
self,
token: str,
session: AsyncSession | None = None,
token_type: str = "access",
):
credentials_exception = HTTPException(
Expand All @@ -64,18 +58,16 @@ async def verify_token(
try:
payload = jwt.decode(
token=token,
key=jwt_settings.secret_key
if token_type == "access"
else jwt_settings.refresh_secret_key,
algorithms=[jwt_settings.jwt_algorithm],
key=jwt_settings.secret_key if token_type == "access" else jwt_settings.refresh_secret_key,
algorithms=[jwt_settings.algorithm],
)

token_exp = payload.get("exp")
if not token_exp:
raise credentials_exception

now = datetime.now(timezone.utc)
if now > datetime.fromtimestamp(token_exp):
if now > datetime.fromtimestamp(token_exp, tz=timezone.utc):
raise credentials_exception

username = payload.get("sub")
Expand All @@ -84,39 +76,29 @@ async def verify_token(

except JWTError:
raise credentials_exception
user = await self.users_service.get_user(username=username, session=session)
user = await self.users_service.get_user(username=username)
if not user:
raise credentials_exception
return user

async def get_current_active_user(
self,
current_user: Annotated[UserModel, Depends(get_current_user)],
current_user: UserModel,
) -> UserModel | HTTPException:
if current_user.disabled:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user"
)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Inactive user")
return current_user

def create_access_token(self, subject: str) -> str:
expires_delta = datetime.now(timezone.utc) + timedelta(
minutes=jwt_settings.access_token_expire_minutes
)
expires_delta = datetime.now(timezone.utc) + timedelta(minutes=jwt_settings.access_token_expire_minutes)
to_encode = {"exp": expires_delta, "sub": subject}
return jwt.encode(
to_encode, jwt_settings.secret_key, algorithm=jwt_settings.algorithm
)
return jwt.encode(to_encode, jwt_settings.secret_key, algorithm=jwt_settings.algorithm)

def create_refresh_token(self, subject: str) -> str:
expires_delta = datetime.now(timezone.utc) + timedelta(
minutes=jwt_settings.refresh_token_expire_minutes
)
expires_delta = datetime.now(timezone.utc) + timedelta(minutes=jwt_settings.refresh_token_expire_minutes)

to_encode = {"exp": expires_delta, "sub": subject}
return jwt.encode(
to_encode, jwt_settings.refresh_secret_key, jwt_settings.algorithm
)
return jwt.encode(to_encode, jwt_settings.refresh_secret_key, jwt_settings.algorithm)


users_controller = UsersController(users_service=users_service)
35 changes: 21 additions & 14 deletions app/routers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
from app.schemas.tokens import Token
from app.schemas.users import UserCreate as UserCreateSchema
from app.schemas.users import UserShow as UserShowSchema
from app.schemas.users import UserUpdate as UserUpdateSchema
from fastapi import APIRouter, Depends, Response, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

router = APIRouter(
prefix="/users", tags=["users"], responses={404: {"description": "Not found"}}
)
router = APIRouter(prefix="/users", tags=["users"], responses={404: {"description": "Not found"}})

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/users/login/", scheme_name="JWT")

Expand All @@ -22,16 +21,12 @@ async def register(user_schema: UserCreateSchema):
return await users_controller.register(user_schema=user_schema)


@router.post(
"/login/", response_model=Token, status_code=status.HTTP_200_OK, summary="Login"
)
@router.post("/login/", response_model=Token, status_code=status.HTTP_200_OK, summary="Login")
async def login(
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
user = await users_controller.authenticate_user(
username=form_data.username, password=form_data.password
)
user = await users_controller.authenticate_user(username=form_data.username, password=form_data.password)
access_token = users_controller.create_access_token(subject=user.username)
refresh_token = users_controller.create_refresh_token(subject=user.username)
response.set_cookie(key="access_token", value=f"{access_token}", httponly=True)
Expand All @@ -48,9 +43,7 @@ async def login(
summary="Refresh Access Token",
)
async def refresh_token(response: Response, refresh_token: str):
user = await users_controller.verify_token(
token=refresh_token, token_type="refresh"
)
user = await users_controller.verify_token(token=refresh_token, token_type="refresh")
access_token = users_controller.create_access_token(subject=user.username)
new_refresh_token = users_controller.create_refresh_token(subject=user.username)

Expand All @@ -63,10 +56,24 @@ async def refresh_token(response: Response, refresh_token: str):


@router.get(
"/users/me/",
"/me/",
response_model=UserShowSchema,
status_code=status.HTTP_200_OK,
summary="Get user info",
)
async def current_user(token: str = Depends(oauth2_scheme)):
async def get_current_user(token: str = Depends(oauth2_scheme)):
return await users_controller.verify_token(token=token)


@router.put(
"/update/",
response_model=UserShowSchema,
status_code=status.HTTP_200_OK,
summary="Update User Info",
)
async def update_user_info(
user_update: UserUpdateSchema,
token: str = Depends(oauth2_scheme),
):
current_user = await users_controller.verify_token(token=token)
return await users_controller.update_user_info(user=current_user, update_user_schema=user_update)
5 changes: 5 additions & 0 deletions app/schemas/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,10 @@ class UserShow(BaseModel):
full_name: str | None = None


class UserUpdate(BaseModel):
email: str | None = None
full_name: str | None


class UserInDB(User):
hashed_password: str
14 changes: 11 additions & 3 deletions app/services/users.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from app.database import with_async_session
from app.models.users import User as UserModel
from app.repositories.users import UsersRepository, users_repository
from passlib.context import CryptContext
Expand All @@ -10,19 +11,26 @@ class UsersService:
def __init__(self, users_repository: UsersRepository) -> None:
self.users_repository = users_repository

async def get_user(self, session: AsyncSession, **kwargs) -> UserModel | None:
@with_async_session
async def get_user(self, session: AsyncSession | None = None, **kwargs) -> UserModel | None:
return await self.users_repository.get(session=session, **kwargs)

async def create(self, user: UserModel, session: AsyncSession) -> UserModel:
@with_async_session
async def create(self, user: UserModel, session: AsyncSession | None = None) -> UserModel:
user.password = self.get_password_hash(user.password)
return await self.users_repository.create(instance=user, session=session)

async def authenticate_user(self, username: str, password: str, session: AsyncSession) -> bool:
@with_async_session
async def authenticate_user(self, username: str, password: str, session: AsyncSession | None = None) -> bool:
user = await self.users_repository.get(session=session, username=username)
if not user:
return False
return user if self.verify_password(plain_password=password, hashed_password=user.password) else False

@with_async_session
async def update_info(self, user: UserModel, session: AsyncSession | None = None) -> UserModel:
return await self.users_repository.update(instance=user, session=session)

def verify_password(self, plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)

Expand Down

0 comments on commit fde6ed1

Please sign in to comment.