From aebdcd0332f27da6be21a33bd712d5132e8fb648 Mon Sep 17 00:00:00 2001 From: Jeny Sadadia Date: Sat, 14 Oct 2023 11:03:32 +0530 Subject: [PATCH] api.auth: optimize `Authentication` class Now the authentication methods have been used from `fastapi-users` package. Hence, remove all the methods from `Authentication` class related to access tokens, scopes, passwords, and getting current user. Also update signature of the class constructor to drop unused variables. Rename variable `Settings.access_token_expire_minutes` to `Settings.access_token_expire_seconds` as `JWTStrategy` class takes token expiry time in seconds. Update logic to create an instance of `Authentication` class in `api.main` module. Signed-off-by: Jeny Sadadia --- api/auth.py | 110 ++++------------------------------------------------ api/main.py | 4 +- 2 files changed, 9 insertions(+), 105 deletions(-) diff --git a/api/auth.py b/api/auth.py index 18c3e9c14..ad616cb72 100644 --- a/api/auth.py +++ b/api/auth.py @@ -6,28 +6,13 @@ """User authentication utilities""" -from datetime import datetime, timedelta -from fastapi.security import OAuth2PasswordBearer -from jose import JWTError, jwt from passlib.context import CryptContext -from pydantic import BaseModel, BaseSettings, Field +from pydantic import BaseSettings from fastapi_users.authentication import ( AuthenticationBackend, BearerTransport, JWTStrategy, ) -from .db import Database -from .models import User - - -class Token(BaseModel): - """Authentication token model""" - access_token: str = Field( - description='Authentication access token' - ) - token_type: str = Field( - description='Access token type e.g. Bearer' - ) class Settings(BaseSettings): @@ -35,108 +20,29 @@ class Settings(BaseSettings): secret_key: str algorithm: str = "HS256" # Set to None so tokens don't expire - access_token_expire_minutes: float = None + access_token_expire_seconds: int = None class Authentication: - """Authentication utility class - - This class accepts a single argument `database` in its constructor, which - should be a db.Database object. - """ + """Authentication utility class""" CRYPT_CTX = CryptContext(schemes=["bcrypt"], deprecated="auto") - def __init__(self, database: Database, token_url: str, user_scopes: dict): - self._db = database + def __init__(self, token_url: str): self._settings = Settings() - self._user_scopes = user_scopes - self._oauth2_scheme = OAuth2PasswordBearer( - tokenUrl=token_url, - scopes=self._user_scopes - ) - - @property - def oauth2_scheme(self): - """Get authentication scheme""" - return self._oauth2_scheme + self._token_url = token_url @classmethod def get_password_hash(cls, password): """Get a password hash for a given clear text password string""" return cls.CRYPT_CTX.hash(password) - @classmethod - def verify_password(cls, password_hash, user): - """Verify that the password hash matches the user's password""" - return cls.CRYPT_CTX.verify(password_hash, user.hashed_password) - - async def authenticate_user(self, username: str, password: str): - """Authenticate a username / password pair - - Look up a `User` in the database with the provided `username` - and check whether the provided clear text `password` matches the hash - associated with it. - """ - user = await self._db.find_one_by_attributes( - User, {'profile.username': username}) - if not user: - return False - if not self.verify_password(password, user.profile): - return False - return user.profile - - def create_access_token(self, data: dict): - """Create a JWT access token using the provided arbitrary `data`""" - to_encode = data.copy() - if self._settings.access_token_expire_minutes: - expires_delta = timedelta( - minutes=self._settings.access_token_expire_minutes - ) - expire = datetime.utcnow() + expires_delta - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode( - to_encode, - self._settings.secret_key, algorithm=self._settings.algorithm - ) - return encoded_jwt - - async def get_current_user(self, token, security_scopes): - """Decode the given JWT `token` and look up a matching `User`""" - try: - payload = jwt.decode( - token, - self._settings.secret_key, - algorithms=[self._settings.algorithm] - ) - username: str = payload.get("sub") - token_scopes = payload.get("scopes", []) - if username is None: - return None, "Could not validate credentials" - - for scope in security_scopes: - if scope not in token_scopes: - return None, "Access denied" - - except JWTError as error: - return None, str(error) - - user = await self._db.find_one_by_attributes( - User, {'profile.username': username}) - return user, None - - async def validate_scopes(self, requested_scopes): - """Check if requested scopes are valid user scopes""" - for scope in requested_scopes: - if scope not in self._user_scopes: - return False, scope - return True, None - def get_jwt_strategy(self) -> JWTStrategy: """Get JWT strategy for authentication backend""" return JWTStrategy( secret=self._settings.secret_key, - lifetime_seconds=self._settings.access_token_expire_minutes + algorithm=self._settings.algorithm, + lifetime_seconds=self._settings.access_token_expire_seconds ) def get_user_authentication_backend(self): @@ -148,7 +54,7 @@ def get_user_authentication_backend(self): Strategy is a method to generate and secure tokens. It can be JWT, database or Redis. """ - bearer_transport = BearerTransport(tokenUrl="user/login") + bearer_transport = BearerTransport(tokenUrl=self._token_url) return AuthenticationBackend( name="jwt", transport=bearer_transport, diff --git a/api/main.py b/api/main.py index 10429189b..83637ff4a 100644 --- a/api/main.py +++ b/api/main.py @@ -48,9 +48,7 @@ app = FastAPI() db = Database(service=(os.getenv('MONGO_SERVICE') or 'mongodb://db:27017')) -auth = Authentication(db, token_url='token', - user_scopes={"admin": "Superusers", - "users": "Regular users"}) +auth = Authentication(token_url="user/login") pubsub = None # pylint: disable=invalid-name auth_backend = auth.get_user_authentication_backend() fastapi_users_instance = FastAPIUsers[User, PydanticObjectId](