diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 48755753ace7c1..989608122185c8 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -12,7 +12,7 @@ class LambdaTranslateUtilsTool(BuiltinTool): def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): msg = { - "src_content": text_content, + "src_contents": [text_content], "src_lang": src_lang, "dest_lang": dest_lang, "dictionary_id": dictionary_name, diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml index 3bb133c7ec8d16..646602fcd6c245 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml @@ -8,9 +8,9 @@ identity: icon: icon.svg description: human: - en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag - zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag - pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag + en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock + zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock + pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock llm: A util tools for translation. parameters: - name: text_content diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py new file mode 100644 index 00000000000000..e05e2d9bf7d356 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py @@ -0,0 +1,67 @@ +import json +from typing import Any, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +# 定义标签映射 +LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"} + + +class ContentModerationTool(BuiltinTool): + sagemaker_client: Any = None + sagemaker_endpoint: str = None + + def _invoke_sagemaker(self, payload: dict, endpoint: str): + response = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + # Parse response + response_body = response["Body"].read().decode("utf8") + + json_obj = json.loads(response_body) + + # Handle nested JSON if present + if isinstance(json_obj, dict) and "body" in json_obj: + body_content = json.loads(json_obj["body"]) + raw_label = body_content.get("label") + else: + raw_label = json_obj.get("label") + + # 映射标签并返回 + result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE + return result + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.sagemaker_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + if not self.sagemaker_endpoint: + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") + + content_text = tool_parameters.get("content_text") + + payload = {"text": content_text} + + result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) + + return self.create_text_message(text=result) + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.yaml b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.yaml new file mode 100644 index 00000000000000..76dcb89632f270 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.yaml @@ -0,0 +1,46 @@ +identity: + name: chinese_toxicity_detector + author: AWS + label: + en_US: Chinese Toxicity Detector + zh_Hans: 中文有害内容检测 + icon: icon.svg +description: + human: + en_US: A tool to detect Chinese toxicity + zh_Hans: 检测中文有害内容的工具 + llm: A tool that checks if Chinese content is safe for work +parameters: + - name: sagemaker_endpoint + type: string + required: true + label: + en_US: sagemaker endpoint for moderation + zh_Hans: 内容审核的SageMaker端点 + human_description: + en_US: sagemaker endpoint for content moderation + zh_Hans: 内容审核的SageMaker端点 + llm_description: sagemaker endpoint for content moderation + form: form + - name: content_text + type: string + required: true + label: + en_US: content text + zh_Hans: 待审核文本 + human_description: + en_US: text content to be moderated + zh_Hans: 需要审核的文本内容 + llm_description: text content to be moderated + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + human_description: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + llm_description: region of sagemaker endpoint + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/transcribe_asr.py b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.py new file mode 100644 index 00000000000000..7520f6bca8b1ce --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.py @@ -0,0 +1,418 @@ +import json +import logging +import os +import re +import time +import uuid +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 +import requests +from botocore.exceptions import ClientError +from requests.exceptions import RequestException + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +LanguageCodeOptions = [ + "af-ZA", + "ar-AE", + "ar-SA", + "da-DK", + "de-CH", + "de-DE", + "en-AB", + "en-AU", + "en-GB", + "en-IE", + "en-IN", + "en-US", + "en-WL", + "es-ES", + "es-US", + "fa-IR", + "fr-CA", + "fr-FR", + "he-IL", + "hi-IN", + "id-ID", + "it-IT", + "ja-JP", + "ko-KR", + "ms-MY", + "nl-NL", + "pt-BR", + "pt-PT", + "ru-RU", + "ta-IN", + "te-IN", + "tr-TR", + "zh-CN", + "zh-TW", + "th-TH", + "en-ZA", + "en-NZ", + "vi-VN", + "sv-SE", + "ab-GE", + "ast-ES", + "az-AZ", + "ba-RU", + "be-BY", + "bg-BG", + "bn-IN", + "bs-BA", + "ca-ES", + "ckb-IQ", + "ckb-IR", + "cs-CZ", + "cy-WL", + "el-GR", + "et-ET", + "eu-ES", + "fi-FI", + "gl-ES", + "gu-IN", + "ha-NG", + "hr-HR", + "hu-HU", + "hy-AM", + "is-IS", + "ka-GE", + "kab-DZ", + "kk-KZ", + "kn-IN", + "ky-KG", + "lg-IN", + "lt-LT", + "lv-LV", + "mhr-RU", + "mi-NZ", + "mk-MK", + "ml-IN", + "mn-MN", + "mr-IN", + "mt-MT", + "no-NO", + "or-IN", + "pa-IN", + "pl-PL", + "ps-AF", + "ro-RO", + "rw-RW", + "si-LK", + "sk-SK", + "sl-SI", + "so-SO", + "sr-RS", + "su-ID", + "sw-BI", + "sw-KE", + "sw-RW", + "sw-TZ", + "sw-UG", + "tl-PH", + "tt-RU", + "ug-CN", + "uk-UA", + "uz-UZ", + "wo-SN", + "zu-ZA", +] + +MediaFormat = ["mp3", "mp4", "wav", "flac", "ogg", "amr", "webm", "m4a"] + + +def is_url(text): + if not text: + return False + text = text.strip() + # Regular expression pattern for URL validation + pattern = re.compile( + r"^" # Start of the string + r"(?:http|https)://" # Protocol (http or https) + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # Domain + r"localhost|" # localhost + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP address + r"(?::\d+)?" # Optional port + r"(?:/?|[/?]\S+)" # Path + r"$", # End of the string + re.IGNORECASE, + ) + return bool(pattern.match(text)) + + +def upload_file_from_url_to_s3(s3_client, url, bucket_name, s3_key=None, max_retries=3): + """ + Upload a file from a URL to an S3 bucket with retries and better error handling. + + Parameters: + - s3_client + - url (str): The URL of the file to upload + - bucket_name (str): The name of the S3 bucket + - s3_key (str): The desired key (path) in S3. If None, will use the filename from URL + - max_retries (int): Maximum number of retry attempts + + Returns: + - tuple: (bool, str) - (Success status, Message) + """ + + # Validate inputs + if not url or not bucket_name: + return False, "URL and bucket name are required" + + retry_count = 0 + while retry_count < max_retries: + try: + # Download the file from URL + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + + # If s3_key is not provided, try to get filename from URL + if not s3_key: + parsed_url = urlparse(url) + filename = os.path.basename(parsed_url.path.split("/file-preview")[0]) + s3_key = "transcribe-files/" + filename + + # Upload the file to S3 + s3_client.upload_fileobj( + response.raw, + bucket_name, + s3_key, + ExtraArgs={ + "ContentType": response.headers.get("content-type"), + "ACL": "private", # Ensure the uploaded file is private + }, + ) + + return f"s3://{bucket_name}/{s3_key}", f"Successfully uploaded file to s3://{bucket_name}/{s3_key}" + + except RequestException as e: + retry_count += 1 + if retry_count == max_retries: + return None, f"Failed to download file from URL after {max_retries} attempts: {str(e)}" + continue + + except ClientError as e: + return None, f"AWS S3 error: {str(e)}" + + except Exception as e: + return None, f"Unexpected error: {str(e)}" + + return None, "Maximum retries exceeded" + + +class TranscribeTool(BuiltinTool): + s3_client: Any = None + transcribe_client: Any = None + + """ + Note that you must include one of LanguageCode, IdentifyLanguage, + or IdentifyMultipleLanguages in your request. + If you include more than one of these parameters, your transcription job fails. + """ + + def _transcribe_audio(self, audio_file_uri, file_type, **extra_args): + uuid_str = str(uuid.uuid4()) + job_name = f"{int(time.time())}-{uuid_str}" + try: + # Start transcription job + response = self.transcribe_client.start_transcription_job( + TranscriptionJobName=job_name, Media={"MediaFileUri": audio_file_uri}, **extra_args + ) + + # Wait for the job to complete + while True: + status = self.transcribe_client.get_transcription_job(TranscriptionJobName=job_name) + if status["TranscriptionJob"]["TranscriptionJobStatus"] in ["COMPLETED", "FAILED"]: + break + time.sleep(5) + + if status["TranscriptionJob"]["TranscriptionJobStatus"] == "COMPLETED": + return status["TranscriptionJob"]["Transcript"]["TranscriptFileUri"], None + else: + return None, f"Error: TranscriptionJobStatus:{status['TranscriptionJob']['TranscriptionJobStatus']} " + + except Exception as e: + return None, f"Error: {str(e)}" + + def _download_and_read_transcript(self, transcript_file_uri: str, max_retries: int = 3) -> tuple[str, str]: + """ + Download and read the transcript file from the given URI. + + Parameters: + - transcript_file_uri (str): The URI of the transcript file + - max_retries (int): Maximum number of retry attempts + + Returns: + - tuple: (text, error) - (Transcribed text if successful, error message if failed) + """ + retry_count = 0 + while retry_count < max_retries: + try: + # Download the transcript file + response = requests.get(transcript_file_uri, timeout=30) + response.raise_for_status() + + # Parse the JSON content + transcript_data = response.json() + + # Check if speaker labels are present and enabled + has_speaker_labels = ( + "results" in transcript_data + and "speaker_labels" in transcript_data["results"] + and "segments" in transcript_data["results"]["speaker_labels"] + ) + + if has_speaker_labels: + # Get speaker segments + segments = transcript_data["results"]["speaker_labels"]["segments"] + items = transcript_data["results"]["items"] + + # Create a mapping of start_time -> speaker_label + time_to_speaker = {} + for segment in segments: + speaker_label = segment["speaker_label"] + for item in segment["items"]: + time_to_speaker[item["start_time"]] = speaker_label + + # Build transcript with speaker labels + current_speaker = None + transcript_parts = [] + + for item in items: + # Skip non-pronunciation items (like punctuation) + if item["type"] == "punctuation": + transcript_parts.append(item["alternatives"][0]["content"]) + continue + + start_time = item["start_time"] + speaker = time_to_speaker.get(start_time) + + if speaker != current_speaker: + current_speaker = speaker + transcript_parts.append(f"\n[{speaker}]: ") + + transcript_parts.append(item["alternatives"][0]["content"]) + + return " ".join(transcript_parts).strip(), None + else: + # Extract the transcription text + # The transcript text is typically in the 'results' -> 'transcripts' array + if "results" in transcript_data and "transcripts" in transcript_data["results"]: + transcripts = transcript_data["results"]["transcripts"] + if transcripts: + # Combine all transcript segments + full_text = " ".join(t.get("transcript", "") for t in transcripts) + return full_text, None + + return None, "No transcripts found in the response" + + except requests.exceptions.RequestException as e: + retry_count += 1 + if retry_count == max_retries: + return None, f"Failed to download transcript file after {max_retries} attempts: {str(e)}" + continue + + except json.JSONDecodeError as e: + return None, f"Failed to parse transcript JSON: {str(e)}" + + except Exception as e: + return None, f"Unexpected error while processing transcript: {str(e)}" + + return None, "Maximum retries exceeded" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.transcribe_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.transcribe_client = boto3.client("transcribe", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + else: + self.transcribe_client = boto3.client("transcribe") + self.s3_client = boto3.client("s3") + + file_url = tool_parameters.get("file_url") + file_type = tool_parameters.get("file_type") + language_code = tool_parameters.get("language_code") + identify_language = tool_parameters.get("identify_language", True) + identify_multiple_languages = tool_parameters.get("identify_multiple_languages", False) + language_options_str = tool_parameters.get("language_options") + s3_bucket_name = tool_parameters.get("s3_bucket_name") + ShowSpeakerLabels = tool_parameters.get("ShowSpeakerLabels", True) + MaxSpeakerLabels = tool_parameters.get("MaxSpeakerLabels", 2) + + # Check the input params + if not s3_bucket_name: + return self.create_text_message(text="s3_bucket_name is required") + language_options = None + if language_options_str: + language_options = language_options_str.split("|") + for lang in language_options: + if lang not in LanguageCodeOptions: + return self.create_text_message( + text=f"{lang} is not supported, should be one of {LanguageCodeOptions}" + ) + if language_code and language_code not in LanguageCodeOptions: + err_msg = f"language_code:{language_code} is not supported, should be one of {LanguageCodeOptions}" + return self.create_text_message(text=err_msg) + + err_msg = f"identify_language:{identify_language}, \ + identify_multiple_languages:{identify_multiple_languages}, \ + Note that you must include one of LanguageCode, IdentifyLanguage, \ + or IdentifyMultipleLanguages in your request. \ + If you include more than one of these parameters, \ + your transcription job fails." + if not language_code: + if identify_language and identify_multiple_languages: + return self.create_text_message(text=err_msg) + else: + if identify_language or identify_multiple_languages: + return self.create_text_message(text=err_msg) + + extra_args = { + "IdentifyLanguage": identify_language, + "IdentifyMultipleLanguages": identify_multiple_languages, + } + if language_code: + extra_args["LanguageCode"] = language_code + if language_options: + extra_args["LanguageOptions"] = language_options + if ShowSpeakerLabels: + extra_args["Settings"] = {"ShowSpeakerLabels": ShowSpeakerLabels, "MaxSpeakerLabels": MaxSpeakerLabels} + + # upload to s3 bucket + s3_path_result, error = upload_file_from_url_to_s3(self.s3_client, url=file_url, bucket_name=s3_bucket_name) + if not s3_path_result: + return self.create_text_message(text=error) + + transcript_file_uri, error = self._transcribe_audio( + audio_file_uri=s3_path_result, + file_type=file_type, + **extra_args, + ) + if not transcript_file_uri: + return self.create_text_message(text=error) + + # Download and read the transcript + transcript_text, error = self._download_and_read_transcript(transcript_file_uri) + if not transcript_text: + return self.create_text_message(text=error) + + return self.create_text_message(text=transcript_text) + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/transcribe_asr.yaml b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.yaml new file mode 100644 index 00000000000000..0dccd615d272dd --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.yaml @@ -0,0 +1,133 @@ +identity: + name: transcribe_asr + author: AWS + label: + en_US: TranscribeASR + zh_Hans: Transcribe语音识别转录 + pt_BR: TranscribeASR + icon: icon.svg +description: + human: + en_US: A tool for ASR (Automatic Speech Recognition) - https://github.com/aws-samples/dify-aws-tool + zh_Hans: AWS 语音识别转录服务, 请参考 https://aws.amazon.com/cn/pm/transcribe/#Learn_More_About_Amazon_Transcribe + pt_BR: A tool for ASR (Automatic Speech Recognition). + llm: A tool for ASR (Automatic Speech Recognition). +parameters: + - name: file_url + type: string + required: true + label: + en_US: video or audio file url for transcribe + zh_Hans: 语音或者视频文件url + pt_BR: video or audio file url for transcribe + human_description: + en_US: video or audio file url for transcribe + zh_Hans: 语音或者视频文件url + pt_BR: video or audio file url for transcribe + llm_description: video or audio file url for transcribe + form: llm + - name: language_code + type: string + required: false + label: + en_US: Language Code + zh_Hans: 语言编码 + pt_BR: Language Code + human_description: + en_US: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html + zh_Hans: 语言编码,例如zh-CN, en-US 可参考 https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html + pt_BR: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html + llm_description: The language code used to create your transcription job. + form: llm + - name: identify_language + type: boolean + default: true + required: false + label: + en_US: Automactically Identify Language + zh_Hans: 自动识别语言 + pt_BR: Automactically Identify Language + human_description: + en_US: Automactically Identify Language + zh_Hans: 自动识别语言 + pt_BR: Automactically Identify Language + llm_description: Enable Automactically Identify Language + form: form + - name: identify_multiple_languages + type: boolean + required: false + label: + en_US: Automactically Identify Multiple Languages + zh_Hans: 自动识别多种语言 + pt_BR: Automactically Identify Multiple Languages + human_description: + en_US: Automactically Identify Multiple Languages + zh_Hans: 自动识别多种语言 + pt_BR: Automactically Identify Multiple Languages + llm_description: Enable Automactically Identify Multiple Languages + form: form + - name: language_options + type: string + required: false + label: + en_US: Language Options + zh_Hans: 语言种类选项 + pt_BR: Language Options + human_description: + en_US: Seperated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media + zh_Hans: 您可以指定两个或更多的语言代码来表示您认为可能出现在媒体中的语言。用|分隔,如 zh-CN|en-US + pt_BR: Seperated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media + llm_description: Seperated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media + form: llm + - name: s3_bucket_name + type: string + required: true + label: + en_US: s3 bucket name + zh_Hans: s3 存储桶名称 + pt_BR: s3 bucket name + human_description: + en_US: s3 bucket name to store transcribe files (don't add prefix s3://) + zh_Hans: s3 存储桶名称,用于存储转录文件 (不需要前缀 s3://) + pt_BR: s3 bucket name to store transcribe files (don't add prefix s3://) + llm_description: s3 bucket name to store transcribe files + form: form + - name: ShowSpeakerLabels + type: boolean + required: true + default: true + label: + en_US: ShowSpeakerLabels + zh_Hans: 显示说话人标签 + pt_BR: ShowSpeakerLabels + human_description: + en_US: Enables speaker partitioning (diarization) in your transcription output + zh_Hans: 在转录输出中启用说话人分区(说话人分离) + pt_BR: Enables speaker partitioning (diarization) in your transcription output + llm_description: Enables speaker partitioning (diarization) in your transcription output + form: form + - name: MaxSpeakerLabels + type: number + required: true + default: 2 + label: + en_US: MaxSpeakerLabels + zh_Hans: 说话人标签数量 + pt_BR: MaxSpeakerLabels + human_description: + en_US: Specify the maximum number of speakers you want to partition in your media + zh_Hans: 指定您希望在媒体中划分的最多演讲者数量。 + pt_BR: Specify the maximum number of speakers you want to partition in your media + llm_description: Specify the maximum number of speakers you want to partition in your media + form: form + - name: aws_region + type: string + required: false + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: Please enter the AWS region for the transcribe service, for example 'us-east-1'. + zh_Hans: 请输入Transcribe的 AWS 区域,例如 'us-east-1'。 + llm_description: Please enter the AWS region for the transcribe service, for example 'us-east-1'. + form: form