-
Notifications
You must be signed in to change notification settings - Fork 8.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Yuanbo Li <[email protected]>
- Loading branch information
1 parent
7a00798
commit de8800f
Showing
8 changed files
with
1,407 additions
and
0 deletions.
There are no files selected for viewing
115 changes: 115 additions & 0 deletions
115
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import json | ||
import operator | ||
from typing import Any, Optional, Union | ||
|
||
import boto3 | ||
|
||
from core.tools.entities.tool_entities import ToolInvokeMessage | ||
from core.tools.tool.builtin_tool import BuiltinTool | ||
|
||
|
||
class BedrockRetrieveTool(BuiltinTool):
    """Builtin tool that retrieves passages from an Amazon Bedrock Knowledge Base.

    The tool calls the ``retrieve`` operation of a boto3 ``bedrock-agent-runtime``
    client and returns the hits as JSON messages ordered by descending score.
    """

    # boto3 "bedrock-agent-runtime" client; created lazily on the first invoke and then reused.
    bedrock_client: Any = None
    # Knowledge base ID used by the most recent invocation.
    knowledge_base_id: Optional[str] = None
    # Number of results requested by the most recent invocation.
    topk: Optional[int] = None

    def _bedrock_retrieve(
        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
    ) -> list[dict[str, Any]]:
        """Run a vector search against a knowledge base.

        :param query_input: text of the query
        :param knowledge_base_id: ID of the knowledge base to search
        :param num_results: value sent as ``numberOfResults``
        :param metadata_filter: optional filter sent as the vector search ``filter``;
            omitted from the request when empty or ``None``
        :return: one dict per hit with the keys ``content``, ``score`` and ``metadata``
        :raises Exception: when the request or the response handling fails; the
            original error is chained as ``__cause__``
        """
        try:
            vector_search_config: dict[str, Any] = {"numberOfResults": num_results}

            # Only add the metadata filter to the retrieval configuration when one was given.
            if metadata_filter:
                vector_search_config["filter"] = metadata_filter

            response = self.bedrock_client.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery={"text": query_input},
                retrievalConfiguration={"vectorSearchConfiguration": vector_search_config},
            )

            return [
                {
                    "content": result.get("content", {}).get("text", ""),
                    "score": result.get("score", 0.0),
                    "metadata": result.get("metadata", {}),
                }
                for result in response.get("retrievalResults", [])
            ]
        except Exception as e:
            # Chain the cause so the underlying boto3 error stays visible in tracebacks.
            raise Exception(f"Error retrieving from knowledge base: {str(e)}") from e

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """Invoke the tool.

        :param user_id: ID of the calling user (not used by this tool)
        :param tool_parameters: reads ``query`` and ``knowledge_base_id``, plus the
            optional ``topk``, ``aws_region`` and ``metadata_filter`` (a JSON string)
        :return: a list of JSON messages, one per retrieved document, sorted by
            descending score; or a single text message describing a missing
            parameter or a failure
        """
        # Stage marker reported in the error message to show how far the invocation got.
        line = 0
        try:
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
                else:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime")

            line = 1
            # Prefer the value supplied with this call; fall back to a previously
            # stored one so a reused instance never silently queries a stale ID.
            knowledge_base_id = tool_parameters.get("knowledge_base_id") or self.knowledge_base_id
            if not knowledge_base_id:
                return self.create_text_message("Please provide knowledge_base_id")
            self.knowledge_base_id = knowledge_base_id

            line = 2
            # The form may deliver the number as a float, a string or None;
            # the request needs an integer.
            self.topk = int(tool_parameters.get("topk") or self.topk or 5)

            line = 3
            query = tool_parameters.get("query", "")
            if not query:
                return self.create_text_message("Please input query")

            # Read the metadata filter (if present); it arrives as a JSON string.
            metadata_filter = None
            metadata_filter_str = tool_parameters.get("metadata_filter")
            if metadata_filter_str:
                metadata_filter = json.loads(metadata_filter_str)
                if not isinstance(metadata_filter, dict):
                    return self.create_text_message("metadata_filter must be a valid JSON object")

            line = 4
            retrieved_docs = self._bedrock_retrieve(
                query_input=query,
                knowledge_base_id=self.knowledge_base_id,
                num_results=self.topk,
                metadata_filter=metadata_filter,  # forward the metadata filter to the retrieval call
            )

            line = 5
            # Sort results by score in descending order
            sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)

            line = 6
            return [self.create_json_message(res) for res in sorted_docs]

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}, line : {line}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """Validate the parameters.

        :param parameters: tool parameters to check
        :raises ValueError: when ``knowledge_base_id`` or ``query`` is missing, or
            when ``metadata_filter`` is given but is not a JSON object
        """
        if not parameters.get("knowledge_base_id"):
            raise ValueError("knowledge_base_id is required")

        if not parameters.get("query"):
            raise ValueError("query is required")

        # Optional: when a metadata filter is provided it must be a JSON string
        # that decodes to an object.
        metadata_filter_str = parameters.get("metadata_filter")
        if metadata_filter_str:
            try:
                metadata_filter = json.loads(metadata_filter_str)
            except ValueError as e:
                # json.JSONDecodeError is a ValueError; re-raise with a message
                # that names the offending parameter.
                raise ValueError("metadata_filter must be a valid JSON object") from e
            if not isinstance(metadata_filter, dict):
                raise ValueError("metadata_filter must be a valid JSON object")
87 changes: 87 additions & 0 deletions
87
api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
identity: | ||
name: bedrock_retrieve | ||
author: AWS | ||
label: | ||
en_US: Bedrock Retrieve | ||
zh_Hans: Bedrock检索 | ||
pt_BR: Bedrock Retrieve | ||
icon: icon.svg | ||
|
||
description: | ||
human: | ||
en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool | ||
zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明 | ||
pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. | ||
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool | ||
|
||
parameters: | ||
- name: knowledge_base_id | ||
type: string | ||
required: true | ||
label: | ||
en_US: Bedrock Knowledge Base ID | ||
zh_Hans: Bedrock知识库ID | ||
pt_BR: Bedrock Knowledge Base ID | ||
human_description: | ||
en_US: ID of the Bedrock Knowledge Base to retrieve from | ||
zh_Hans: 用于检索的Bedrock知识库ID | ||
pt_BR: ID of the Bedrock Knowledge Base to retrieve from | ||
llm_description: ID of the Bedrock Knowledge Base to retrieve from | ||
form: form | ||
|
||
- name: query | ||
type: string | ||
required: true | ||
label: | ||
en_US: Query string | ||
zh_Hans: 查询语句 | ||
pt_BR: Query string | ||
human_description: | ||
en_US: The search query to retrieve relevant information | ||
zh_Hans: 用于检索相关信息的查询语句 | ||
pt_BR: The search query to retrieve relevant information | ||
llm_description: The search query to retrieve relevant information | ||
form: llm | ||
|
||
- name: topk | ||
type: number | ||
required: false | ||
form: form | ||
label: | ||
en_US: Limit for results count | ||
zh_Hans: 返回结果数量限制 | ||
pt_BR: Limit for results count | ||
human_description: | ||
en_US: Maximum number of results to return | ||
zh_Hans: 最大返回结果数量 | ||
pt_BR: Maximum number of results to return | ||
min: 1 | ||
max: 10 | ||
default: 5 | ||
|
||
- name: aws_region | ||
type: string | ||
required: false | ||
label: | ||
en_US: AWS Region | ||
zh_Hans: AWS 区域 | ||
pt_BR: AWS Region | ||
human_description: | ||
en_US: AWS region where the Bedrock Knowledge Base is located | ||
zh_Hans: Bedrock知识库所在的AWS区域 | ||
pt_BR: AWS region where the Bedrock Knowledge Base is located | ||
llm_description: AWS region where the Bedrock Knowledge Base is located | ||
form: form | ||
|
||
- name: metadata_filter | ||
type: string | ||
required: false | ||
label: | ||
en_US: Metadata Filter | ||
zh_Hans: 元数据过滤器 | ||
pt_BR: Metadata Filter | ||
human_description: | ||
en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})' | ||
zh_Hans: '元数据的JSON格式过滤条件(例如,{"greaterThan": {"key": "aaa", "value": 10}})' | ||
pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key": "aaa", "value": 10}})' | ||
form: form |
Oops, something went wrong.