Skip to content

Commit

Permalink
refactor: update builtin tool provider methods to use session managem…
Browse files Browse the repository at this point in the history
…ent (#11938)

Signed-off-by: -LAN- <[email protected]>
  • Loading branch information
laipz8200 authored Dec 21, 2024
1 parent 8f73670 commit 606aadb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
24 changes: 14 additions & 10 deletions api/controllers/console/workspace/tool_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService
Expand Down Expand Up @@ -91,26 +93,28 @@ def post(self, provider):

args = parser.parse_args()

return BuiltinToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
args["credentials"],
)
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
)
session.commit()
return result


class ToolBuiltinProviderGetCredentialsApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id

return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
tenant_id=tenant_id,
provider_name=provider,
)


Expand Down
33 changes: 16 additions & 17 deletions api/services/tools/builtin_tools_manage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import logging
from pathlib import Path

from sqlalchemy import select
from sqlalchemy.orm import Session

from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
Expand Down Expand Up @@ -32,7 +35,7 @@ def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str
tenant_id=tenant_id, provider_controller=provider_controller
)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = (
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
Expand Down Expand Up @@ -71,19 +74,18 @@ def list_builtin_provider_credentials_schema(provider_name):
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])

@staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
"""
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
provider = session.scalar(stmt)

try:
# get provider
Expand Down Expand Up @@ -115,29 +117,26 @@ def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st
encrypted_credentials=json.dumps(credentials),
)

db.session.add(provider)
db.session.commit()
session.add(provider)

else:
provider.encrypted_credentials = json.dumps(credentials)
db.session.add(provider)
db.session.commit()

# delete cache
tool_configuration.delete_tool_credentials_cache()

return {"result": "success"}

@staticmethod
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
"""
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = (
provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
Expand All @@ -156,7 +155,7 @@ def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: st
"""
delete tool provider
"""
provider: BuiltinToolProvider = (
provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
Expand Down

0 comments on commit 606aadb

Please sign in to comment.