From 6cc2d079b9ecdedf86c5115d37eebf920e7efd40 Mon Sep 17 00:00:00 2001 From: FireFading Date: Tue, 17 Oct 2023 00:25:13 +0300 Subject: [PATCH] refresh token --- app/controllers/users.py | 32 ++++++++++++++++++++++---------- app/database.py | 4 ++-- app/routers/users.py | 39 ++++++++++++++++++++++++++------------- app/schemas/users.py | 1 + app/settings.py | 7 +++---- 5 files changed, 54 insertions(+), 29 deletions(-) diff --git a/app/controllers/users.py b/app/controllers/users.py index 90f8b1d..8e42ffb 100644 --- a/app/controllers/users.py +++ b/app/controllers/users.py @@ -2,14 +2,14 @@ from typing import Annotated from app.config import jwt_settings +from app.database import with_async_session +from app.models.users import User as UserModel from app.schemas.tokens import TokenData from app.schemas.users import UserCreate as UserCreateSchema from app.services.users import UsersService, users_service from fastapi import Depends, HTTPException, status from jose import JWTError, jwt -from app.models.users import User as UserModel from sqlalchemy.ext.asyncio import AsyncSession -from app.database import with_async_session class UsersController: @@ -50,11 +50,12 @@ async def authenticate_user( return user @with_async_session - async def get_current_user( + async def verify_token( self, token: str, session: AsyncSession | None = None, - ) -> UserModel | HTTPException: + token_type: str = "access", + ): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -62,17 +63,28 @@ async def get_current_user( ) try: payload = jwt.decode( - token, jwt_settings.secret_key, algorithms=[jwt_settings.algorithm] + token=token, + key=jwt_settings.secret_key + if token_type == "access" + else jwt_settings.refresh_secret_key, + algorithms=[jwt_settings.jwt_algorithm], ) + + token_exp = payload.get("exp") + if not token_exp: + raise credentials_exception + + now = datetime.now(timezone.utc) + if now > datetime.fromtimestamp(token_exp): + raise credentials_exception + username = payload.get("sub") if not username: raise credentials_exception - token_data = TokenData(username=username) + except JWTError: raise credentials_exception - user = await self.users_service.get_user( - username=token_data.username, session=session - ) + user = await self.users_service.get_user(username=username, session=session) if not user: raise credentials_exception return user @@ -103,7 +115,7 @@ def create_refresh_token(self, subject: str) -> str: to_encode = {"exp": expires_delta, "sub": subject} return jwt.encode( - to_encode, jwt_settings.jwt_refresh_secret_key, jwt_settings.algorithm + to_encode, jwt_settings.refresh_secret_key, jwt_settings.algorithm ) diff --git a/app/database.py b/app/database.py index b8e4575..5cb814f 100644 --- a/app/database.py +++ b/app/database.py @@ -1,9 +1,9 @@ from collections.abc import AsyncGenerator +from functools import wraps from app.config import database_settings from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import declarative_base, sessionmaker -from functools import wraps engine = create_async_engine(database_settings.database_url, future=True, echo=True) async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) @@ -23,4 +23,4 @@ async def wrapper(*args, **kwargs): kwargs["session"] = session return await func(*args, **kwargs) - return wrapper \ No newline at end of file + return wrapper diff --git a/app/routers/users.py b/app/routers/users.py index 0551a1b..8da48cf 100644 --- a/app/routers/users.py +++ b/app/routers/users.py @@ -1,17 +1,9 @@ -from typing import Annotated - from app.controllers.users import users_controller -from app.models.users import User as UserModel -from app.schemas.users import ( - User as UserSchema, - UserCreate as UserCreateSchema, - UserShow as UserShowSchema, -) from app.schemas.tokens import Token -from fastapi import APIRouter, Depends, status, Response -from fastapi.security import OAuth2PasswordRequestForm -from fastapi.security import OAuth2PasswordBearer - +from app.schemas.users import UserCreate as UserCreateSchema +from app.schemas.users import UserShow as UserShowSchema +from fastapi import APIRouter, Depends, Response, status +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm router = APIRouter( prefix="/users", tags=["users"], responses={404: {"description": "Not found"}} @@ -49,6 +41,27 @@ async def login( } +@router.post( + "/refresh-token/", + response_model=Token, + status_code=status.HTTP_200_OK, + summary="Refresh Access Token", +) +async def refresh_token(response: Response, refresh_token: str): + user = await users_controller.verify_token( + token=refresh_token, token_type="refresh" + ) + access_token = users_controller.create_access_token(subject=user.username) + new_refresh_token = users_controller.create_refresh_token(subject=user.username) + + response.set_cookie(key="access_token", value=access_token, httponly=True) + response.set_cookie(key="refresh_token", value=new_refresh_token, httponly=True) + + return { + "access_token": access_token, + } + + @router.get( "/users/me/", response_model=UserShowSchema, @@ -56,4 +69,4 @@ async def login( summary="Get user info", ) async def current_user(token: str = Depends(oauth2_scheme)): - return await users_controller.get_current_user(token=token) + return await users_controller.verify_token(token=token) diff --git a/app/schemas/users.py b/app/schemas/users.py index 309524f..a6c6660 100644 --- a/app/schemas/users.py +++ b/app/schemas/users.py @@ -20,5 +20,6 @@ class UserShow(BaseModel): email: str | None = None full_name: str | None = None + class UserInDB(User): hashed_password: str diff --git a/app/settings.py b/app/settings.py index 820f280..bb03fbe 100644 --- a/app/settings.py +++ b/app/settings.py @@ -5,7 +5,7 @@ class PostgresSettings(BaseSettings): - model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra='allow') + model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra="allow") database_url: str @@ -16,11 +16,10 @@ class PostgresSettings(BaseSettings): class JWTSettings(BaseSettings): - model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra='allow') + model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra="allow") secret_key: str algorithm: str access_token_expire_minutes: int refresh_token_expire_minutes: int - jwt_refresh_secret_key: str - + refresh_secret_key: str