Skip to content

Commit

Permalink
fix: buitin tool aippt (#10234)
Browse files Browse the repository at this point in the history
Co-authored-by: jinqi.guo <[email protected]>
  • Loading branch information
guogeer and jinqi.guo authored Nov 4, 2024
1 parent 6b0de08 commit 971defb
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 30 deletions.
78 changes: 49 additions & 29 deletions api/core/tools/provider/builtin/aippt/tools/aippt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from json import loads as json_loads
from threading import Lock
from time import sleep, time
from typing import Any, Optional
from typing import Any

from httpx import get, post
from requests import get as requests_get
Expand All @@ -15,27 +15,27 @@
from core.tools.tool.builtin_tool import BuiltinTool


class AIPPTGenerateTool(BuiltinTool):
class AIPPTGenerateToolAdapter:
"""
A tool for generating a ppt
"""

_api_base_url = URL("https://co.aippt.cn/api")
_api_token_cache = {}
_api_token_cache_lock: Optional[Lock] = None
_style_cache = {}
_style_cache_lock: Optional[Lock] = None

_api_token_cache_lock = Lock()
_style_cache_lock = Lock()

_task = {}
_task_type_map = {
"auto": 1,
"markdown": 7,
}
_tool: BuiltinTool

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._api_token_cache_lock = Lock()
self._style_cache_lock = Lock()
def __init__(self, tool: BuiltinTool = None):
self._tool = tool

def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
"""
Expand All @@ -51,11 +51,11 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
"""
title = tool_parameters.get("title", "")
if not title:
return self.create_text_message("Please provide a title for the ppt")
return self._tool.create_text_message("Please provide a title for the ppt")

model = tool_parameters.get("model", "aippt")
if not model:
return self.create_text_message("Please provide a model for the ppt")
return self._tool.create_text_message("Please provide a model for the ppt")

outline = tool_parameters.get("outline", "")

Expand All @@ -68,8 +68,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
)

# get suit
color = tool_parameters.get("color")
style = tool_parameters.get("style")
color: str = tool_parameters.get("color")
style: str = tool_parameters.get("style")

if color == "__default__":
color_id = ""
Expand All @@ -93,9 +93,9 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe
# generate ppt
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)

return self.create_text_message(
return self._tool.create_text_message(
"""the ppt has been created successfully,"""
f"""the ppt url is {ppt_url}"""
f"""the ppt url is {ppt_url} ."""
"""please give the ppt url to user and direct user to download it."""
)

Expand All @@ -111,8 +111,8 @@ def _create_task(self, type: int, title: str, content: str, user_id: str) -> str
"""
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}
response = post(
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
Expand All @@ -139,8 +139,8 @@ def _generate_outline(self, task_id: str, model: str, user_id: str) -> str:

headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}

response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
Expand Down Expand Up @@ -183,8 +183,8 @@ def _generate_content(self, task_id: str, model: str, user_id: str) -> str:

headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}

response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
Expand Down Expand Up @@ -236,14 +236,15 @@ def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]:
"""
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
}

response = post(
str(self._api_base_url / "design" / "v2" / "save"),
headers=headers,
data={"task_id": task_id, "template_id": suit_id},
timeout=(10, 60),
)

if response.status_code != 200:
Expand Down Expand Up @@ -350,11 +351,13 @@ def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:

return token

@classmethod
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
@staticmethod
def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
return b64encode(
hmac_new(
key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
key=secret_key.encode("utf-8"),
msg=f"GET@/api/grant/token/@{timestamp}".encode(),
digestmod=sha1,
).digest()
).decode("utf-8")

Expand Down Expand Up @@ -419,19 +422,21 @@ def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
:param credentials: the credentials
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
"""
if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get(
"aippt_secret_key"
):
raise Exception("Please provide aippt credentials")

return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id)

def _get_suit(self, style_id: int, colour_id: int) -> int:
"""
Get suit
"""
headers = {
"x-channel": "",
"x-api-key": self.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"),
}
response = get(
str(self._api_base_url / "template_component" / "suit" / "search"),
Expand Down Expand Up @@ -496,3 +501,18 @@ def get_runtime_parameters(self) -> list[ToolParameter]:
],
),
]


class AIPPTGenerateTool(BuiltinTool):
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)

def get_runtime_parameters(self) -> list[ToolParameter]:
return AIPPTGenerateToolAdapter(self).get_runtime_parameters()

@classmethod
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id)
2 changes: 1 addition & 1 deletion api/core/workflow/nodes/tool/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _run(self) -> NodeRunResult:
)

# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
tool_parameters = tool_runtime.parameters or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
Expand Down

0 comments on commit 971defb

Please sign in to comment.