Skip to content

Commit

Permalink
add three aws tools (#11905)
Browse files Browse the repository at this point in the history
Co-authored-by: Yuanbo Li <[email protected]>
  • Loading branch information
ybalbert001 and Yuanbo Li authored Dec 21, 2024
1 parent 7a00798 commit de8800f
Show file tree
Hide file tree
Showing 8 changed files with 1,407 additions and 0 deletions.
115 changes: 115 additions & 0 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import json
import operator
from typing import Any, Optional, Union

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class BedrockRetrieveTool(BuiltinTool):
bedrock_client: Any = None
knowledge_base_id: str = None
topk: int = None

def _bedrock_retrieve(
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
):
try:
retrieval_query = {"text": query_input}

retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

# 如果有元数据过滤条件,则添加到检索配置中
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

response = self.bedrock_client.retrieve(
knowledgeBaseId=knowledge_base_id,
retrievalQuery=retrieval_query,
retrievalConfiguration=retrieval_configuration,
)

results = []
for result in response.get("retrievalResults", []):
results.append(
{
"content": result.get("content", {}).get("text", ""),
"score": result.get("score", 0.0),
"metadata": result.get("metadata", {}),
}
)

return results
except Exception as e:
raise Exception(f"Error retrieving from knowledge base: {str(e)}")

def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
line = 0
try:
if not self.bedrock_client:
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
else:
self.bedrock_client = boto3.client("bedrock-agent-runtime")

line = 1
if not self.knowledge_base_id:
self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
if not self.knowledge_base_id:
return self.create_text_message("Please provide knowledge_base_id")

line = 2
if not self.topk:
self.topk = tool_parameters.get("topk", 5)

line = 3
query = tool_parameters.get("query", "")
if not query:
return self.create_text_message("Please input query")

# 获取元数据过滤条件(如果存在)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
)

line = 5
# Sort results by score in descending order
sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)

line = 6
return [self.create_json_message(res) for res in sorted_docs]

except Exception as e:
return self.create_text_message(f"Exception {str(e)}, line : {line}")

def validate_parameters(self, parameters: dict[str, Any]) -> None:
"""
Validate the parameters
"""
if not parameters.get("knowledge_base_id"):
raise ValueError("knowledge_base_id is required")

if not parameters.get("query"):
raise ValueError("query is required")

# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")
87 changes: 87 additions & 0 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
identity:
name: bedrock_retrieve
author: AWS
label:
en_US: Bedrock Retrieve
zh_Hans: Bedrock检索
pt_BR: Bedrock Retrieve
icon: icon.svg

description:
human:
en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool

parameters:
- name: knowledge_base_id
type: string
required: true
label:
en_US: Bedrock Knowledge Base ID
zh_Hans: Bedrock知识库ID
pt_BR: Bedrock Knowledge Base ID
human_description:
en_US: ID of the Bedrock Knowledge Base to retrieve from
zh_Hans: 用于检索的Bedrock知识库ID
pt_BR: ID of the Bedrock Knowledge Base to retrieve from
llm_description: ID of the Bedrock Knowledge Base to retrieve from
form: form

- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 查询语句
pt_BR: Query string
human_description:
en_US: The search query to retrieve relevant information
zh_Hans: 用于检索相关信息的查询语句
pt_BR: The search query to retrieve relevant information
llm_description: The search query to retrieve relevant information
form: llm

- name: topk
type: number
required: false
form: form
label:
en_US: Limit for results count
zh_Hans: 返回结果数量限制
pt_BR: Limit for results count
human_description:
en_US: Maximum number of results to return
zh_Hans: 最大返回结果数量
pt_BR: Maximum number of results to return
min: 1
max: 10
default: 5

- name: aws_region
type: string
required: false
label:
en_US: AWS Region
zh_Hans: AWS 区域
pt_BR: AWS Region
human_description:
en_US: AWS region where the Bedrock Knowledge Base is located
zh_Hans: Bedrock知识库所在的AWS区域
pt_BR: AWS region where the Bedrock Knowledge Base is located
llm_description: AWS region where the Bedrock Knowledge Base is located
form: form

- name: metadata_filter
type: string
required: false
label:
en_US: Metadata Filter
zh_Hans: 元数据过滤器
pt_BR: Metadata Filter
human_description:
en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})'
pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
form: form
Loading

0 comments on commit de8800f

Please sign in to comment.