Skip to content

Commit

Permalink
refresh token
Browse files Browse the repository at this point in the history
  • Loading branch information
FireFading committed Oct 16, 2023
1 parent 9ab7042 commit 6cc2d07
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 29 deletions.
32 changes: 22 additions & 10 deletions app/controllers/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
from typing import Annotated

from app.config import jwt_settings
from app.database import with_async_session
from app.models.users import User as UserModel
from app.schemas.tokens import TokenData
from app.schemas.users import UserCreate as UserCreateSchema
from app.services.users import UsersService, users_service
from fastapi import Depends, HTTPException, status
from jose import JWTError, jwt
from app.models.users import User as UserModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import with_async_session


class UsersController:
Expand Down Expand Up @@ -50,29 +50,41 @@ async def authenticate_user(
return user

@with_async_session
async def get_current_user(
async def verify_token(
self,
token: str,
session: AsyncSession | None = None,
) -> UserModel | HTTPException:
token_type: str = "access",
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, jwt_settings.secret_key, algorithms=[jwt_settings.algorithm]
token=token,
key=jwt_settings.secret_key
if token_type == "access"
else jwt_settings.refresh_secret_key,
algorithms=[jwt_settings.jwt_algorithm],
)

token_exp = payload.get("exp")
if not token_exp:
raise credentials_exception

now = datetime.now(timezone.utc)
if now > datetime.fromtimestamp(token_exp):
raise credentials_exception

username = payload.get("sub")
if not username:
raise credentials_exception
token_data = TokenData(username=username)

except JWTError:
raise credentials_exception
user = await self.users_service.get_user(
username=token_data.username, session=session
)
user = await self.users_service.get_user(username=username, session=session)
if not user:
raise credentials_exception
return user
Expand Down Expand Up @@ -103,7 +115,7 @@ def create_refresh_token(self, subject: str) -> str:

to_encode = {"exp": expires_delta, "sub": subject}
return jwt.encode(
to_encode, jwt_settings.jwt_refresh_secret_key, jwt_settings.algorithm
to_encode, jwt_settings.refresh_secret_key, jwt_settings.algorithm
)


Expand Down
4 changes: 2 additions & 2 deletions app/database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from collections.abc import AsyncGenerator
from functools import wraps

from app.config import database_settings
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker
from functools import wraps

engine = create_async_engine(database_settings.database_url, future=True, echo=True)
async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
Expand All @@ -23,4 +23,4 @@ async def wrapper(*args, **kwargs):
kwargs["session"] = session
return await func(*args, **kwargs)

return wrapper
return wrapper
39 changes: 26 additions & 13 deletions app/routers/users.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,9 @@
from typing import Annotated

from app.controllers.users import users_controller
from app.models.users import User as UserModel
from app.schemas.users import (
User as UserSchema,
UserCreate as UserCreateSchema,
UserShow as UserShowSchema,
)
from app.schemas.tokens import Token
from fastapi import APIRouter, Depends, status, Response
from fastapi.security import OAuth2PasswordRequestForm
from fastapi.security import OAuth2PasswordBearer

from app.schemas.users import UserCreate as UserCreateSchema
from app.schemas.users import UserShow as UserShowSchema
from fastapi import APIRouter, Depends, Response, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm

router = APIRouter(
prefix="/users", tags=["users"], responses={404: {"description": "Not found"}}
Expand Down Expand Up @@ -49,11 +41,32 @@ async def login(
}


@router.post(
"/refresh-token/",
response_model=Token,
status_code=status.HTTP_200_OK,
summary="Refresh Access Token",
)
async def refresh_token(response: Response, refresh_token: str):
user = await users_controller.verify_token(
token=refresh_token, token_type="refresh"
)
access_token = users_controller.create_access_token(subject=user.username)
new_refresh_token = users_controller.create_refresh_token(subject=user.username)

response.set_cookie(key="access_token", value=access_token, httponly=True)
response.set_cookie(key="refresh_token", value=new_refresh_token, httponly=True)

return {
"access_token": access_token,
}


@router.get(
"/users/me/",
response_model=UserShowSchema,
status_code=status.HTTP_200_OK,
summary="Get user info",
)
async def current_user(token: str = Depends(oauth2_scheme)):
return await users_controller.get_current_user(token=token)
return await users_controller.verify_token(token=token)
1 change: 1 addition & 0 deletions app/schemas/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ class UserShow(BaseModel):
email: str | None = None
full_name: str | None = None


class UserInDB(User):
hashed_password: str
7 changes: 3 additions & 4 deletions app/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


class PostgresSettings(BaseSettings):
model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra='allow')
model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra="allow")

database_url: str

Expand All @@ -16,11 +16,10 @@ class PostgresSettings(BaseSettings):


class JWTSettings(BaseSettings):
model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra='allow')
model_config = SettingsConfigDict(env_file="./.env.example", env_file_encoding="utf-8", extra="allow")

secret_key: str
algorithm: str
access_token_expire_minutes: int
refresh_token_expire_minutes: int
jwt_refresh_secret_key: str

refresh_secret_key: str

0 comments on commit 6cc2d07

Please sign in to comment.