Commit cc0b92b
Co-authored-by: Yuanbo Li <[email protected]>
1 parent: e576d32
Showing 6 changed files with 668 additions and 4 deletions.
api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py (67 additions, 0 deletions)
@@ -0,0 +1,67 @@
import json
from typing import Any, Union

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

# Define the label mapping for the endpoint output
LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}


class ContentModerationTool(BuiltinTool):
    sagemaker_client: Any = None
    sagemaker_endpoint: str = None

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        response = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        # Parse response
        response_body = response["Body"].read().decode("utf8")

        json_obj = json.loads(response_body)

        # Handle nested JSON if present
        if isinstance(json_obj, dict) and "body" in json_obj:
            body_content = json.loads(json_obj["body"])
            raw_label = body_content.get("label")
        else:
            raw_label = json_obj.get("label")

        # Map the label and return; default to NO_SAFE if the label is not in the mapping
        result = LABEL_MAPPING.get(raw_label, "NO_SAFE")
        return result

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        try:
            if not self.sagemaker_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")

            if not self.sagemaker_endpoint:
                self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint")

            content_text = tool_parameters.get("content_text")

            payload = {"text": content_text}

            result = self._invoke_sagemaker(payload, self.sagemaker_endpoint)

            return self.create_text_message(text=result)

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}")
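
The tool above assumes the SageMaker endpoint accepts a JSON body of the form {"text": ...} and responds with either a flat {"label": "LABEL_0"|"LABEL_1"} object or that object wrapped in a "body" envelope. A minimal sketch of that parsing contract in isolation, using made-up sample responses rather than a live endpoint:

# Hedged sketch: illustrates the two response shapes the tool can parse.
# The sample payloads below are invented for illustration only.
import json

LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}

def parse_moderation_response(response_body: str) -> str:
    json_obj = json.loads(response_body)
    # Unwrap the nested "body" envelope if the endpoint returns one
    if isinstance(json_obj, dict) and "body" in json_obj:
        json_obj = json.loads(json_obj["body"])
    return LABEL_MAPPING.get(json_obj.get("label"), "NO_SAFE")

# Flat response
print(parse_moderation_response('{"label": "LABEL_0"}'))  # SAFE
# Nested response
print(parse_moderation_response('{"body": "{\\"label\\": \\"LABEL_1\\"}"}'))  # NO_SAFE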
api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.yaml (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
identity:
  name: chinese_toxicity_detector
  author: AWS
  label:
    en_US: Chinese Toxicity Detector
    zh_Hans: 中文有害内容检测
  icon: icon.svg
description:
  human:
    en_US: A tool to detect Chinese toxicity
    zh_Hans: 检测中文有害内容的工具
  llm: A tool that checks if Chinese content is safe for work
parameters:
  - name: sagemaker_endpoint
    type: string
    required: true
    label:
      en_US: sagemaker endpoint for moderation
      zh_Hans: 内容审核的SageMaker端点
    human_description:
      en_US: sagemaker endpoint for content moderation
      zh_Hans: 内容审核的SageMaker端点
    llm_description: sagemaker endpoint for content moderation
    form: form
  - name: content_text
    type: string
    required: true
    label:
      en_US: content text
      zh_Hans: 待审核文本
    human_description:
      en_US: text content to be moderated
      zh_Hans: 需要审核的文本内容
    llm_description: text content to be moderated
    form: llm
  - name: aws_region
    type: string
    required: false
    label:
      en_US: region of sagemaker endpoint
      zh_Hans: SageMaker 端点所在的region
    human_description:
      en_US: region of sagemaker endpoint
      zh_Hans: SageMaker 端点所在的region
    llm_description: region of sagemaker endpoint
    form: form
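
Before wiring the tool into a workflow, the endpoint named by sagemaker_endpoint can be smoke-tested directly with boto3 to confirm it returns the {"label": ...} shape the Python tool expects. A minimal sketch; the endpoint name, region, and sample text are placeholders, not values from this commit:

import json
import boto3

# Placeholder values; substitute your own endpoint name and region.
ENDPOINT_NAME = "my-toxicity-detector-endpoint"
REGION = "us-east-1"

client = boto3.client("sagemaker-runtime", region_name=REGION)
response = client.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    Body=json.dumps({"text": "测试文本"}),
    ContentType="application/json",
)
# Expected output shape: {"label": "LABEL_0"} or {"label": "LABEL_1"},
# possibly wrapped in a "body" envelope depending on how the endpoint is deployed.
print(response["Body"].read().decode("utf8"))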