Skip to content

Commit

Permalink
Merge branch 'main' into p56
Browse files Browse the repository at this point in the history
  • Loading branch information
hjlarry authored Sep 9, 2024
2 parents bf326ec + a771eea commit d4bc394
Show file tree
Hide file tree
Showing 97 changed files with 1,941 additions and 314 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def _invoke(self, model: str, credentials: dict,
endpoint_url,
headers=headers,
data=json.dumps(payload),
timeout=(10, 300)
timeout=(10, 300),
options={"use_mmap": "true"}
)

response.raise_for_status() # Raise an exception for HTTP errors
Expand Down
317 changes: 274 additions & 43 deletions api/core/model_runtime/model_providers/sagemaker/llm/llm.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

class SageMakerRerankModel(RerankModel):
"""
Model class for Cohere rerank model.
Model class for SageMaker rerank model.
"""
sagemaker_client: Any = None

Expand Down
28 changes: 27 additions & 1 deletion api/core/model_runtime/model_providers/sagemaker/sagemaker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import uuid
from typing import IO, Any

from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class SageMakerProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Expand All @@ -15,3 +16,28 @@ def validate_provider_credentials(self, credentials: dict) -> None:
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
pass

def buffer_to_s3(s3_client:Any, file: IO[bytes], bucket:str, s3_prefix:str) -> str:
'''
return s3_uri of this file
'''
s3_key = f'{s3_prefix}{uuid.uuid4()}.mp3'
s3_client.put_object(
Body=file.read(),
Bucket=bucket,
Key=s3_key,
ContentType='audio/mp3'
)
return s3_key

def generate_presigned_url(s3_client:Any, file: IO[bytes], bucket_name:str, s3_prefix:str, expiration=600) -> str:
object_key = buffer_to_s3(s3_client, file, bucket_name, s3_prefix)
try:
response = s3_client.generate_presigned_url('get_object',
Params={'Bucket': bucket_name, 'Key': object_key},
ExpiresIn=expiration)
except Exception as e:
print(f"Error generating presigned URL: {e}")
return None

return response
78 changes: 73 additions & 5 deletions api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ supported_model_types:
- llm
- text-embedding
- rerank
- speech2text
- tts
configurate_methods:
- customizable-model
model_credential_schema:
Expand All @@ -45,14 +47,10 @@ model_credential_schema:
zh_Hans: 选择对话类型
en_US: Select completion mode
options:
- value: completion
label:
en_US: Completion
zh_Hans: 补全
- value: chat
label:
en_US: Chat
zh_Hans: 对话
zh_Hans: Chat
- variable: sagemaker_endpoint
label:
en_US: sagemaker endpoint
Expand All @@ -61,6 +59,76 @@ model_credential_schema:
placeholder:
zh_Hans: 请输出你的Sagemaker推理端点
en_US: Enter your Sagemaker Inference endpoint
- variable: audio_s3_cache_bucket
show_on:
- variable: __model_type
value: speech2text
label:
zh_Hans: 音频缓存桶(s3 bucket)
en_US: audio cache bucket(s3 bucket)
type: text-input
required: true
placeholder:
zh_Hans: sagemaker-us-east-1-******207838
en_US: sagemaker-us-east-1-*******7838
- variable: audio_model_type
show_on:
- variable: __model_type
value: tts
label:
en_US: Audio model type
type: select
required: true
placeholder:
zh_Hans: 语音模型类型
en_US: Audio model type
options:
- value: PresetVoice
label:
en_US: preset voice
zh_Hans: 内置音色
- value: CloneVoice
label:
en_US: clone voice
zh_Hans: 克隆音色
- value: CloneVoice_CrossLingual
label:
en_US: crosslingual clone voice
zh_Hans: 跨语种克隆音色
- value: InstructVoice
label:
en_US: Instruct voice
zh_Hans: 文字指令音色
- variable: prompt_audio
show_on:
- variable: __model_type
value: tts
label:
en_US: Mock Audio Source
type: text-input
required: false
placeholder:
zh_Hans: 被模仿的音色音频
en_US: source audio to be mocked
- variable: prompt_text
show_on:
- variable: __model_type
value: tts
label:
en_US: Prompt Audio Text
type: text-input
required: false
placeholder:
zh_Hans: 模仿音色的对应文本
en_US: text for the mocked source audio
- variable: instruct_text
show_on:
- variable: __model_type
value: tts
label:
en_US: instruct text for speaker
type: text-input
required: false
- variable: aws_access_key_id
required: false
label:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import json
import logging
from typing import IO, Any, Optional

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url

logger = logging.getLogger(__name__)

class SageMakerSpeech2TextModel(Speech2TextModel):
"""
Model class for Xinference speech to text model.
"""
sagemaker_client: Any = None
s3_client : Any = None

def _invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
-> str:
"""
Invoke speech2text model
:param model: model name
:param credentials: model credentials
:param file: audio file
:param user: unique user id
:return: text for given audio file
"""
asr_text = None

try:
if not self.sagemaker_client:
access_key = credentials.get('aws_access_key_id')
secret_key = credentials.get('aws_secret_access_key')
aws_region = credentials.get('aws_region')
if aws_region:
if access_key and secret_key:
self.sagemaker_client = boto3.client("sagemaker-runtime",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
self.s3_client = boto3.client("s3",
aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
self.s3_client = boto3.client("s3", region_name=aws_region)
else:
self.sagemaker_client = boto3.client("sagemaker-runtime")
self.s3_client = boto3.client("s3")

s3_prefix='dify/speech2text/'
sagemaker_endpoint = credentials.get('sagemaker_endpoint')
bucket = credentials.get('audio_s3_cache_bucket')

s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
payload = {
"audio_s3_presign_uri" : s3_presign_url
}

response_model = self.sagemaker_client.invoke_endpoint(
EndpointName=sagemaker_endpoint,
Body=json.dumps(payload),
ContentType="application/json"
)
json_str = response_model['Body'].read().decode('utf8')
json_obj = json.loads(json_str)
asr_text = json_obj['text']
except Exception as e:
logger.exception(f'Exception {e}, line : {line}')

return asr_text

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
pass

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
InvokeConnectionError
],
InvokeServerUnavailableError: [
InvokeServerUnavailableError
],
InvokeRateLimitError: [
InvokeRateLimitError
],
InvokeAuthorizationError: [
InvokeAuthorizationError
],
InvokeBadRequestError: [
InvokeBadRequestError,
KeyError,
ValueError
]
}

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
used to define customizable model schema
"""
entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.SPEECH2TEXT,
model_properties={ },
parameter_rules=[]
)

return entity
Empty file.
Loading

0 comments on commit d4bc394

Please sign in to comment.