Skip to content

Commit

Permalink
Convert JWTStore to use authlib (#4722)
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator authored Sep 24, 2024
1 parent 7a41f5a commit a5ed26f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions integration_tests/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ openinference-semantic-conventions
opentelemetry-sdk
portpicker
psutil
pyjwt
pytest-randomly
pytest-smtpd
types-beautifulsoup4
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ dependencies = [
"fastapi",
"fastapi-mail",
"pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting
"pyjwt",
"authlib",
]
dynamic = ["version"]
Expand Down
15 changes: 8 additions & 7 deletions src/phoenix/server/jwt_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from functools import cached_property, singledispatchmethod
from typing import Any, Callable, Coroutine, Dict, Generic, List, Optional, Tuple, Type, TypeVar

import jwt
from authlib.jose import jwt
from authlib.jose.errors import BadSignatureError
from sqlalchemy import Select, delete, select

from phoenix.auth import (
Expand Down Expand Up @@ -75,12 +76,10 @@ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
async def read(self, token: Token) -> Optional[ClaimSet]:
try:
payload = jwt.decode(
token,
self._secret,
algorithms=[JWT_ALGORITHM],
options={"verify_exp": False},
s=token,
key=self._secret,
)
except jwt.DecodeError:
except BadSignatureError:
return None
if (jti := payload.get("jti")) is None:
return None
Expand Down Expand Up @@ -247,7 +246,9 @@ def _encode(self, claim: ClaimSet) -> str:
payload: Dict[str, Any] = dict(jti=claim.token_id)
if claim.expiration_time:
payload["exp"] = int(claim.expiration_time.timestamp())
return jwt.encode(payload, self._secret, algorithm=self._algorithm)
header = {"alg": self._algorithm}
jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=self._secret)
return jwt_bytes.decode()

async def get(self, token_id: _TokenIdT) -> Optional[_ClaimSetT]:
if claims := self._claims.get(token_id):
Expand Down

0 comments on commit a5ed26f

Please sign in to comment.