-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
integrated cognito auth into contract post
- Loading branch information
Showing
33 changed files
with
330 additions
and
558 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,89 +1,51 @@ | ||
from flask import Flask, Response, request | ||
from flask_cors import CORS | ||
from flask_cognito import CognitoAuth | ||
from flasgger import Swagger | ||
from controllers import ContractController, MeController, HealthController | ||
from flask_restful import Api | ||
from fastapi import FastAPI, Request, Response, status | ||
from fastapi.responses import JSONResponse | ||
from time import strftime | ||
import logging | ||
from utilities.types import JSONDict | ||
from config.env import COGNITO_REGION, COGNITO_USERPOOL_ID, COGNITO_APP_CLIENT_ID | ||
from managers import MeManager | ||
from http import HTTPStatus | ||
from utilities.types import FlaskResponseType | ||
import traceback | ||
from database import CognitoIdentityProviderWrapper | ||
|
||
app = Flask(__name__) | ||
|
||
print(CognitoIdentityProviderWrapper().get_user("test")) | ||
|
||
app.config.update({ | ||
'COGNITO_REGION': COGNITO_REGION, | ||
'COGNITO_USERPOOL_ID': COGNITO_USERPOOL_ID, | ||
'COGNITO_APP_CLIENT_ID': COGNITO_APP_CLIENT_ID, | ||
|
||
# optional | ||
'COGNITO_CHECK_TOKEN_EXPIRATION': True | ||
}) | ||
|
||
app.config['SWAGGER'] = { | ||
'title': 'AADR Backend API' | ||
} | ||
|
||
|
||
cogauth = CognitoAuth(app) | ||
cogauth.init_app(app) | ||
CORS(app) | ||
Swagger(app) | ||
api = Api(app) | ||
import uvicorn | ||
from fastapi_cloudauth.cognito import Cognito | ||
from starlette.middleware.base import _StreamingResponse | ||
from typing import Awaitable, Callable | ||
|
||
app = FastAPI() | ||
auth = Cognito( | ||
region=COGNITO_REGION, | ||
userPoolId=COGNITO_USERPOOL_ID, | ||
client_id=COGNITO_APP_CLIENT_ID | ||
) | ||
|
||
logging.getLogger().setLevel(logging.INFO) | ||
api.add_resource(ContractController, '/contract') | ||
api.add_resource(MeController, "/me") | ||
api.add_resource(HealthController, "/health") | ||
|
||
|
||
@cogauth.identity_handler | ||
def lookup_cognito_user(payload: JSONDict) -> str: | ||
"""Look up user in our database from Cognito JWT payload.""" | ||
assert 'sub' in payload, "Invalid Cognito JWT payload" | ||
user_id = payload['sub'] | ||
|
||
me_manager = MeManager() | ||
user = me_manager.get_user_from_db(user_id) | ||
|
||
# Add database information to payload | ||
payload['database'] = user | ||
|
||
# ID tokens contain 'cognito:username' in payload instead of 'username' | ||
username = None | ||
if "cognito:username" in payload: | ||
username = payload['cognito:username'] | ||
elif "username" in payload: | ||
username = payload['username'] | ||
|
||
assert type(username) == str, "Invalid username" | ||
|
||
return username | ||
app.include_router(ContractController(auth).router) | ||
app.include_router(MeController(auth).router) | ||
app.include_router(HealthController(auth).router) | ||
|
||
|
||
@app.after_request | ||
def after_request(response: Response) -> Response: | ||
@app.middleware("http") | ||
async def after_request(request: Request, call_next: Callable[..., Awaitable[_StreamingResponse]]) -> Response: | ||
response: Response = await call_next(request) | ||
timestamp = strftime('[%Y-%b-%d %H:%M]') # TODO this is defined in multiple spots. Make robust | ||
logging.info('%s %s %s %s %s %s', timestamp, request.remote_addr, request.method, request.scheme, request.full_path, response.status) | ||
assert request.client, "Missing header data in request. No client information." | ||
logging.info('%s %s %s %s %s %s', timestamp, request.client.host, request.method, request.scope['type'], request.url, response.status_code) | ||
return response | ||
|
||
|
||
# @app.errorhandler(Exception) # type: ignore[type-var] | ||
def exceptions(e: Exception) -> FlaskResponseType: | ||
@app.exception_handler(Exception) | ||
def exceptions(request: Request, e: Exception) -> JSONResponse: | ||
tb = traceback.format_exc() | ||
timestamp = strftime('[%Y-%b-%d %H:%M]') | ||
logging.error('%s %s %s %s %s 5xx INTERNAL SERVER ERROR\n%s', timestamp, request.remote_addr, request.method, request.scheme, request.full_path, tb) | ||
assert request.client, "Missing header data in request. No client information." | ||
logging.error('%s %s %s %s %s 5xx INTERNAL SERVER ERROR\n%s', timestamp, request.client.host, request.method, request.scope['type'], request.url, tb) | ||
logging.error(e) | ||
return "Internal server error", HTTPStatus.INTERNAL_SERVER_ERROR | ||
return JSONResponse( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
content=None | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.getLogger().setLevel(logging.DEBUG) | ||
app.run(debug=True, host="0.0.0.0", port=3001) | ||
uvicorn.run(app, host="0.0.0.0", port=3001) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
from .contract import ContractController | ||
from .base_controller import BaseController | ||
from .swagger import * | ||
from .health import HealthController | ||
from .me import MeController |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,18 @@ | ||
from flask import Response, abort | ||
from http import HTTPStatus | ||
from flask_restful import Resource | ||
from flask import request | ||
from flasgger import validate | ||
import json | ||
from typing import Dict, Any, Union | ||
from utilities.types import JSONDict | ||
from jsonschema.exceptions import ValidationError | ||
import logging | ||
from flask_cognito import cognito_auth_required, current_cognito_jwt | ||
from fastapi import APIRouter | ||
from time import strftime | ||
from fastapi_cloudauth.cognito import Cognito | ||
from fastapi import Request | ||
|
||
|
||
class BaseController(Resource): # type: ignore[no-any-unimported] | ||
class BaseController: | ||
|
||
@classmethod | ||
def log_debug(cls, msg: str) -> None: | ||
timestamp = strftime('[%Y-%b-%d %H:%M]') | ||
logging.debug('%s %s %s %s %s %s', timestamp, request.remote_addr, request.method, request.scheme, request.full_path, msg) | ||
def __init__(self, auth: Cognito): # type: ignore[no-any-unimported] | ||
self.router = APIRouter() | ||
self.auth = auth | ||
|
||
@classmethod | ||
def get_request_data(cls, swagger_data: Union[str, JSONDict], swagger_object_id: str) -> Dict[str, Any]: | ||
""" | ||
Gets and verifies request data. | ||
It is preferred to use a .yaml str filepath for swagger_data, | ||
but for dynamic swagger API's based on configs, use a dictionary of the spec | ||
""" | ||
data = request.get_json() | ||
assert type(data) == dict, "Invalid data in request" | ||
cls.log_debug(json.dumps(data)) | ||
if type(swagger_data) is dict: | ||
validate(data, swagger_object_id, specs=swagger_data, validation_error_handler=cls.error_handler) | ||
else: | ||
validate(data, swagger_object_id, swagger_data, validation_error_handler=cls.error_handler) | ||
return data | ||
|
||
@classmethod | ||
def abort_request(cls, message: str, status: int) -> None: | ||
abort(Response(json.dumps({'error': message}), status=status)) | ||
|
||
@classmethod | ||
def error_handler(cls, err: ValidationError, data: JSONDict, schema: JSONDict) -> None: | ||
""" | ||
Error handler for flasgger | ||
""" | ||
error_message = str(err.message) | ||
cls.log_debug(error_message) | ||
cls.abort_request(error_message, HTTPStatus.BAD_REQUEST) | ||
|
||
@classmethod | ||
@cognito_auth_required | ||
def verify_id_token(cls) -> None: | ||
""" | ||
Returns 400 if header token is not id | ||
""" | ||
if current_cognito_jwt['token_use'] != "id": | ||
cls.abort_request("Header must contain an ID token", HTTPStatus.BAD_REQUEST) | ||
def log_debug(cls, msg: str, request: Request) -> None: | ||
timestamp = strftime('[%Y-%b-%d %H:%M]') | ||
assert request.client, "Missing header data in request. No client information." | ||
logging.debug('%s %s %s %s %s %s', timestamp, request.client.host, request.method, request.scope['type'], request.url, msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,63 @@ | ||
from flask_cognito import cognito_auth_required, current_cognito_jwt, current_user | ||
from managers import ContractManager | ||
from flasgger import swag_from | ||
from .base_controller import BaseController | ||
from utilities.types import FlaskResponseType | ||
from utilities import FlaskResponses, NoApproverException | ||
from .swagger.contract.post import contract_post_schema | ||
from utilities.types import JSONDict | ||
from utilities import NoApproverException | ||
from utilities.types import HelperModel | ||
from typing import Optional | ||
from fastapi_cloudauth.cognito import Cognito | ||
from fastapi_cloudauth.cognito import CognitoClaims | ||
from utilities.auth import get_current_user | ||
from fastapi import status, Depends, Response, HTTPException | ||
from fastapi.responses import JSONResponse | ||
from pydantic import BaseModel, Field | ||
from utilities.types.fields import phone_number | ||
from config import Config | ||
from typing import List | ||
from database.users import UsersDB | ||
|
||
config = Config() | ||
|
||
|
||
class PostItem(BaseModel): | ||
artist_phone_number: int = phone_number("artistPhoneNumber") | ||
helpers: Optional[List[HelperModel]] = Field(alias="helpers", min_length=1, max_length=config.get_contract_limit("max_helpers")) | ||
num_additional_chairs: int = Field(alias="numAdditionalChairs", le=config.get_contract_limit("max_additional_chairs"), ge=0, examples=['2']) | ||
|
||
|
||
class PostResponseItem(BaseModel): | ||
contractId: int = 0 | ||
|
||
|
||
class ContractController(BaseController): | ||
|
||
@cognito_auth_required | ||
@swag_from(contract_post_schema) | ||
def post(self) -> FlaskResponseType: | ||
data = self.get_request_data(contract_post_schema, "ContractData") | ||
def __init__(self, auth: Cognito): # type: ignore[no-any-unimported] | ||
super().__init__(auth) | ||
self.router.add_api_route("/contract", self.post, methods=["POST"], response_model=PostResponseItem) | ||
|
||
user_db: Optional[JSONDict] = current_cognito_jwt['database'] | ||
if user_db is None: | ||
return FlaskResponses.bad_request("User needs to make an account") | ||
async def post(self, item: PostItem, current_user: CognitoClaims = Depends(get_current_user)) -> Response: # type: ignore[no-any-unimported] | ||
db = await UsersDB.get_user(current_user.sub) | ||
if not db: | ||
raise HTTPException( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail='User needs to make an account' | ||
) | ||
|
||
try: | ||
result = ContractManager().create_contract( | ||
current_cognito_jwt['sub'], | ||
contract_type=user_db['vendor_type'], | ||
helpers=data.get('helpers'), | ||
num_additional_chairs=data['numAdditionalChairs'], | ||
signer_email=current_cognito_jwt['email'], # TODO assert that emails are verified | ||
signer_name=str(current_user), | ||
artist_phone_number=data['artistPhoneNumber'] | ||
result = await ContractManager().create_contract( | ||
current_user.sub, | ||
contract_type=str(db.get("vendor_type")), | ||
helpers=item.helpers, | ||
num_additional_chairs=item.num_additional_chairs, | ||
signer_email=current_user.email, # TODO assert that emails are verified | ||
signer_name=current_user.username, # TODO signer_name should be the user's name, not username | ||
artist_phone_number=item.artist_phone_number # TODO this should be stored in AWS | ||
) | ||
except NoApproverException: | ||
return FlaskResponses.conflict("Cannot make contract since there is nobody to approve the contract.") | ||
raise HTTPException( | ||
status_code=status.HTTP_409_CONFLICT, | ||
detail='Cannot make contract since there is nobody to approve the contract' | ||
) | ||
|
||
return FlaskResponses.success(result) | ||
return JSONResponse( | ||
status_code=status.HTTP_200_OK, | ||
content=result | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,17 @@ | ||
from flasgger import swag_from | ||
from .base_controller import BaseController | ||
from utilities.types import FlaskResponseType | ||
from utilities import FlaskResponses | ||
from fastapi import status, Response | ||
from fastapi_cloudauth.cognito import Cognito | ||
from fastapi.responses import JSONResponse | ||
|
||
|
||
class HealthController(BaseController): | ||
|
||
@swag_from("swagger/health/get.yaml") | ||
def get(self) -> FlaskResponseType: | ||
return FlaskResponses().success("ok") | ||
def __init__(self, auth: Cognito): # type: ignore[no-any-unimported] | ||
super().__init__(auth) | ||
self.router.add_api_route("/health", self.get, methods=["GET"], response_model=None) | ||
|
||
def get(self) -> Response: | ||
return JSONResponse( | ||
status_code=status.HTTP_200_OK, | ||
content=None | ||
) |
Oops, something went wrong.