Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: Implement authentication for API v2 based on OAuth2 potocol. #388

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
35 changes: 35 additions & 0 deletions app/api/v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
# isort: dont-add-imports

from typing import Any

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import status

from app.api.v2.common.oauth import OAuth2Scheme
from app.repositories import access_tokens as access_tokens_repo

oauth2_scheme = OAuth2Scheme(
authorizationUrl="/v2/oauth/authorize",
tokenUrl="/v2/oauth/token",
refreshUrl="/v2/oauth/refresh",
scheme_name="OAuth2 for third-party clients.",
scopes={
"public": "Access endpoints with public data.",
"identify": "Access endpoints with user's data.",
"admin": "Access admin endpoints.",
},
)


async def get_current_client(token: str = Depends(oauth2_scheme)) -> dict[str, Any]:
cmyui marked this conversation as resolved.
Show resolved Hide resolved
"""Look up the token in the Redis-based token store"""
access_token = await access_tokens_repo.fetch_one(token)
if not access_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
return access_token


from . import clans
from . import maps
from . import oauth
from . import players
from . import scores

Expand All @@ -13,3 +47,4 @@
apiv2_router.include_router(maps.router)
apiv2_router.include_router(players.router)
apiv2_router.include_router(scores.router)
apiv2_router.include_router(oauth.router)
9 changes: 8 additions & 1 deletion app/api/v2/common/json.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any
from uuid import UUID

import orjson
from fastapi.responses import JSONResponse
Expand All @@ -14,6 +15,8 @@ def _default_processor(data: Any) -> Any:
return {k: _default_processor(v) for k, v in data.items()}
elif isinstance(data, list):
return [_default_processor(v) for v in data]
elif isinstance(data, UUID):
return str(data)
else:
return data

Expand All @@ -22,8 +25,12 @@ def dumps(data: Any) -> bytes:
return orjson.dumps(data, default=_default_processor)


def loads(data: str) -> Any:
return orjson.loads(data)


class ORJSONResponse(JSONResponse):
media_type = "application/json"
media_type = "application/json;charset=UTF-8"

def render(self, content: Any) -> bytes:
return dumps(content)
85 changes: 85 additions & 0 deletions app/api/v2/common/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations
Copy link
Member Author

@alowave223 alowave223 Jan 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import breaks FastApi's Depends. Related issue: fastapi/fastapi#1654

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i ran into this a long while ago at work, nice work actually tracking it down lol, that's not an easy error


import base64

from fastapi import Request
from fastapi import status
from fastapi.exceptions import HTTPException
from fastapi.openapi.models import OAuthFlowAuthorizationCode
from fastapi.openapi.models import OAuthFlowClientCredentials
from fastapi.openapi.models import OAuthFlows
from fastapi.security import OAuth2
from fastapi.security.utils import get_authorization_scheme_param


class OAuth2Scheme(OAuth2):
def __init__(
self,
authorizationUrl: str,
tokenUrl: str,
refreshUrl: str | None = None,
scheme_name: str | None = None,
scopes: dict[str, str] | None = None,
description: str | None = None,
auto_error: bool = True,
):
if not scopes:
scopes = {}
flows = OAuthFlows(
authorizationCode=OAuthFlowAuthorizationCode(
authorizationUrl=authorizationUrl,
tokenUrl=tokenUrl,
scopes=scopes,
refreshUrl=refreshUrl,
),
clientCredentials=OAuthFlowClientCredentials(
tokenUrl=tokenUrl,
scopes=scopes,
refreshUrl=refreshUrl,
),
)
super().__init__(
flows=flows,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)

async def __call__(self, request: Request) -> str | None:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return None
return param


# https://developer.zendesk.com/api-reference/sales-crm/authentication/requests/#client-authentication
def get_credentials_from_basic_auth(
request: Request,
) -> dict[str, str | int] | None:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "basic":
return None

data = base64.b64decode(param).decode("utf-8")
if ":" not in data:
return None

split = data.split(":")
if len(split) != 2:
return None
if not split[0].isdecimal():
return None

return {
"client_id": int(split[0]),
"client_secret": split[1],
}
28 changes: 28 additions & 0 deletions app/api/v2/models/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import annotations

from datetime import datetime
from enum import StrEnum
from typing import Literal

from . import BaseModel

# input models


# output models


class GrantType(StrEnum):
AUTHORIZATION_CODE = "authorization_code"
CLIENT_CREDENTIALS = "client_credentials"

# TODO: Add support for other grant types


class Token(BaseModel):
access_token: str
refresh_token: str | None
token_type: Literal["Bearer"]
expires_in: int
expires_at: datetime
scope: str
Loading
Loading