Skip to content

Commit

Permalink
Merge branch 'refs/heads/feat/new-login' into deploy/dev
Browse files Browse the repository at this point in the history
* refs/heads/feat/new-login:
  feat: remove redict signin
  fix: oauth AccountNotFound

# Conflicts:
#	api/controllers/console/auth/oauth.py
  • Loading branch information
ZhouhaoJiang committed Sep 2, 2024
2 parents 53d83c4 + 910f9b7 commit 3223aac
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions api/controllers/console/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models.account import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountNotFound

from .. import api

Expand Down Expand Up @@ -44,14 +43,15 @@ def get_oauth_providers():

class OAuthLogin(Resource):
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
print(vars(oauth_provider))
if not oauth_provider:
return {"error": "Invalid provider"}, 400

auth_url = oauth_provider.get_authorization_url()
auth_url = oauth_provider.get_authorization_url(invite_token)
return redirect(auth_url)


Expand All @@ -64,13 +64,21 @@ def get(self, provider: str):
return {"error": "Invalid provider"}, 400

code = request.args.get("code")
state = request.args.get("state")
invite_token = None
if state:
invite_token = state

try:
token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token)
except requests.exceptions.HTTPError as e:
logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}")
return {"error": "OAuth process failed"}, 400

if invite_token:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/invite-settings?invite_token={invite_token}")

try:
account = _generate_account(provider, user_info)
except services.errors.account.AccountNotFound as e:
Expand Down Expand Up @@ -104,7 +112,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)

if not account and dify_config.ALLOW_REGISTER:
if not account:
account_name = user_info.name if user_info.name else "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
Expand All @@ -118,8 +126,6 @@ def _generate_account(provider: str, user_info: OAuthUserInfo):
interface_language = languages[0]
account.interface_language = interface_language
db.session.commit()
else:
raise AccountNotFound()

# Link account
AccountService.link_account_integrate(provider, user_info.id, account)
Expand Down

0 comments on commit 3223aac

Please sign in to comment.