diff --git a/sdk/python/packages/flet-runtime/src/flet_runtime/auth/providers/__init__.py b/sdk/python/packages/flet-runtime/src/flet_runtime/auth/providers/__init__.py new file mode 100644 index 000000000..fa12f3654 --- /dev/null +++ b/sdk/python/packages/flet-runtime/src/flet_runtime/auth/providers/__init__.py @@ -0,0 +1,4 @@ +from flet_runtime.auth.providers.auth0_oauth_provider import Auth0OAuthProvider +from flet_runtime.auth.providers.azure_oauth_provider import AzureOAuthProvider +from flet_runtime.auth.providers.github_oauth_provider import GitHubOAuthProvider +from flet_runtime.auth.providers.google_oauth_provider import GoogleOAuthProvider diff --git a/sdk/python/packages/flet/src/flet/auth/authorization.py b/sdk/python/packages/flet/src/flet/auth/authorization.py index 2034060c9..999cec0a2 100644 --- a/sdk/python/packages/flet/src/flet/auth/authorization.py +++ b/sdk/python/packages/flet/src/flet/auth/authorization.py @@ -7,7 +7,7 @@ import httpx from flet.auth.oauth_provider import OAuthProvider -from flet.auth.oauth_token import OAuthToken +from flet.auth.oauth_token import OAuthToken, WeChatOAuthToken from flet.auth.user import User from flet.version import version from flet_core.locks import AsyncNopeLock, NopeLock @@ -78,6 +78,7 @@ def get_authorization_data(self) -> Tuple[str, str]: state=self.state, code_challenge=self.provider.code_challenge, code_challenge_method=self.provider.code_challenge_method, + appid=self.provider.client_id, ) return authorization_url, self.state @@ -103,6 +104,17 @@ async def request_token_async(self, code: str): def __get_request_token_request(self, code: str): client = WebApplicationClient(self.provider.client_id) + headers = self.__get_default_headers() + if self.is_wechat_oauth_provider(): + data = client.prepare_request_body( + secret=self.provider.client_secret, + code=code, + appid=self.provider.client_id, + include_client_id=False, + ) + return httpx.Request( + "GET", self.provider.token_endpoint, params=data, headers=headers + ) data = client.prepare_request_body( code=code, redirect_uri=self.provider.redirect_url, @@ -110,7 +122,6 @@ def __get_request_token_request(self, code: str): include_client_id=True, code_verifier=self.provider.code_verifier, ) - headers = self.__get_default_headers() headers["content-type"] = "application/x-www-form-urlencoded" return httpx.Request( "POST", self.provider.token_endpoint, content=data, headers=headers @@ -119,7 +130,12 @@ def __get_request_token_request(self, code: str): def __fetch_user_and_groups(self): assert self.__token is not None if self.fetch_user: - self.user = self.provider._fetch_user(self.__token.access_token) + if self.is_wechat_oauth_provider(): + self.user = self.provider._fetch_user( + self.__token.access_token, self.__token.openid + ) + else: + self.user = self.provider._fetch_user(self.__token.access_token) if self.user is None and self.provider.user_endpoint is not None: if self.provider.user_id_fn is None: raise Exception( @@ -134,7 +150,14 @@ def __fetch_user_and_groups(self): async def __fetch_user_and_groups_async(self): assert self.__token is not None if self.fetch_user: - self.user = await self.provider._fetch_user_async(self.__token.access_token) + if self.is_wechat_oauth_provider(): + self.user = await self.provider._fetch_user_async( + self.__token.access_token, self.__token.openid + ) + else: + self.user = await self.provider._fetch_user_async( + self.__token.access_token + ) if self.user is None and self.provider.user_endpoint is not None: if self.provider.user_id_fn is None: raise Exception( @@ -146,7 +169,24 @@ async def __fetch_user_and_groups_async(self): self.__token.access_token ) + def is_wechat_oauth_provider(self): + return ( + self.provider.token_endpoint + == "https://api.weixin.qq.com/sns/oauth2/access_token" + ) + def __convert_token(self, t: OAuth2Token): + if self.is_wechat_oauth_provider(): + return WeChatOAuthToken( + access_token=t["access_token"], + scope=t.get("scope"), + token_type=t.get("token_type"), + expires_in=t.get("expires_in"), + expires_at=t.get("expires_at"), + refresh_token=t.get("refresh_token"), + openid=t.get("openid"), + unionid=t.get("unionid"), + ) return OAuthToken( access_token=t["access_token"], scope=t.get("scope"), diff --git a/sdk/python/packages/flet/src/flet/auth/oauth_token.py b/sdk/python/packages/flet/src/flet/auth/oauth_token.py index 0156e2d8b..514b518fb 100644 --- a/sdk/python/packages/flet/src/flet/auth/oauth_token.py +++ b/sdk/python/packages/flet/src/flet/auth/oauth_token.py @@ -28,3 +28,27 @@ def to_json(self): def from_json(data: str): t = json.loads(data) return OAuthToken(**t) + + +class WeChatOAuthToken(OAuthToken): + def __init__( + self, + access_token: str, + scope: Optional[List[str]] = None, + token_type: Optional[str] = None, + expires_in: Optional[int] = None, + expires_at: Optional[float] = None, + refresh_token: Optional[str] = None, + openid: Optional[str] = None, + unionid: Optional[str] = None, + ) -> None: + super().__init__( + access_token=access_token, + scope=scope, + token_type=token_type, + expires_in=expires_in, + expires_at=expires_at, + refresh_token=refresh_token, + ) + self.openid = openid + self.unionid = unionid diff --git a/sdk/python/packages/flet/src/flet/auth/providers/__init__.py b/sdk/python/packages/flet/src/flet/auth/providers/__init__.py index 0d434e441..9242098cf 100644 --- a/sdk/python/packages/flet/src/flet/auth/providers/__init__.py +++ b/sdk/python/packages/flet/src/flet/auth/providers/__init__.py @@ -2,3 +2,4 @@ from flet.auth.providers.azure_oauth_provider import AzureOAuthProvider from flet.auth.providers.github_oauth_provider import GitHubOAuthProvider from flet.auth.providers.google_oauth_provider import GoogleOAuthProvider +from flet.auth.providers.wechat_oauth_provider import WeChatOAuthProvider diff --git a/sdk/python/packages/flet/src/flet/auth/providers/wechat_oauth_provider.py b/sdk/python/packages/flet/src/flet/auth/providers/wechat_oauth_provider.py new file mode 100644 index 000000000..91ef2245b --- /dev/null +++ b/sdk/python/packages/flet/src/flet/auth/providers/wechat_oauth_provider.py @@ -0,0 +1,68 @@ +from typing import List, Optional + +import httpx +from flet_runtime.auth.oauth_provider import OAuthProvider +from flet_runtime.auth.user import User +from flet_runtime.version import version + + +class WeChatOAuthProvider(OAuthProvider): + """ + OAuth provider for WeChat authentication. + + WeChat's OAuth flow differs from standard implementations: + - Uses a unique 'code' parameter instead of typical 'access_token' + - Requires additional steps for user info retrieval + - Implements state parameter differently for security + """ + + def __init__( + self, + client_id: str, + client_secret: str, + redirect_url: str, + scopes: Optional[List[str]] = ["snsapi_login"], + ) -> None: + super().__init__( + client_id=client_id, + client_secret=client_secret, + authorization_endpoint="https://open.weixin.qq.com/connect/qrconnect", + token_endpoint="https://api.weixin.qq.com/sns/oauth2/access_token", + user_endpoint="https://api.weixin.qq.com/sns/userinfo", + redirect_url=redirect_url, + scopes=scopes, + ) + + def _fetch_user(self, access_token: str, openid: str) -> Optional[User]: + user_req = self.__get_user_details_requests(access_token, openid) + with httpx.Client(follow_redirects=True) as client: + user_resp = client.send(user_req) + return self.__complete_fetch_user_details(user_resp) + + async def _fetch_user_async(self, access_token: str, openid: str) -> Optional[User]: + user_req = self.__get_user_details_requests(access_token, openid) + async with httpx.AsyncClient() as client: + user_resp = await client.send(user_req) + return self.__complete_fetch_user_details(user_resp) + + def __get_user_details_requests(self, access_token, openid): + params = { + "access_token": access_token, + "openid": openid, + } + return httpx.Request( + "GET", + self.user_endpoint, + params=params, + headers=self.__get_client_headers(), + ) + + def __complete_fetch_user_details(self, user_resp): + user_resp.raise_for_status() + uj = user_resp.json() + return User(uj, id=str(uj["openid"])) + + def __get_client_headers(self): + return { + "User-Agent": f"Flet/{version}", + }