Skip to content

Commit

Permalink
feat: add support of speech2text function for OpenAI-API-compatible and Siliconflow (#7197)
Browse files Browse the repository at this point in the history
  • Loading branch information
alfredcai authored Aug 12, 2024
1 parent 57ce844 commit a12ddc4
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ description:
supported_model_types:
- llm
- text-embedding
- speech2text
configurate_methods:
- customizable-model
model_credential_schema:
Expand Down Expand Up @@ -61,6 +62,22 @@ model_credential_schema:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
show_on:
- variable: __model_type
value: llm
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
show_on:
- variable: __model_type
value: text-embedding
type: text-input
default: '4096'
placeholder:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import IO, Optional
from urllib.parse import urljoin

import requests

from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat


class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
    """
    Model class for OpenAI Compatible Speech to text model.
    """

    def _invoke(
        self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
    ) -> str:
        """
        Invoke speech2text model

        The audio is uploaded as multipart form data to
        ``<endpoint_url>/audio/transcriptions`` and the ``text`` field of the
        JSON response is returned.

        :param model: model name
        :param credentials: model credentials; ``endpoint_url`` is required,
            ``api_key`` is optional and sent as a Bearer token when present
        :param file: audio file
        :param user: unique user id (accepted for interface compatibility;
            not forwarded to the endpoint)
        :return: text for given audio file
        :raises InvokeBadRequestError: if ``endpoint_url`` is missing or the
            endpoint answers with a non-200 status
        """
        endpoint_url = credentials.get("endpoint_url")
        if not endpoint_url:
            # Fail with a clear message instead of an AttributeError on None.
            raise InvokeBadRequestError("endpoint_url is required in credentials")

        headers = {}
        api_key = credentials.get("api_key")
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # urljoin drops the last path segment unless the base ends with "/".
        if not endpoint_url.endswith("/"):
            endpoint_url += "/"
        endpoint_url = urljoin(endpoint_url, "audio/transcriptions")

        payload = {"model": model}
        files = [("file", file)]
        # requests applies no timeout by default; without one an unresponsive
        # endpoint would block the caller forever. (connect, read) in seconds.
        response = requests.post(
            endpoint_url,
            headers=headers,
            data=payload,
            files=files,
            timeout=(10, 300),
        )

        if response.status_code != 200:
            raise InvokeBadRequestError(response.text)
        response_data = response.json()
        return response_data["text"]

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        Transcribes the demo audio file returned by ``_get_demo_file_path``;
        any failure is reported as a credential validation failure.

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if the demo invocation fails
        """
        try:
            audio_file_path = self._get_demo_file_path()

            with open(audio_file_path, "rb") as audio_file:
                self._invoke(model, credentials, audio_file)
        except Exception as ex:
            # Chain the original exception so the root cause stays visible.
            raise CredentialsValidateFailedError(str(ex)) from ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

logger = logging.getLogger(__name__)


class SiliconflowProvider(ModelProvider):

def validate_provider_credentials(self, credentials: dict) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ help:
supported_model_types:
- llm
- text-embedding
- speech2text
configurate_methods:
- predefined-model
provider_credential_schema:
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
model: iic/SenseVoiceSmall
model_type: speech2text
model_properties:
file_upload_limit: 1
supported_file_extensions: mp3,wav
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import IO, Optional

from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel


class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel):
    """
    Model class for Siliconflow Speech to text model.

    Thin wrapper over the OpenAI-compatible implementation that pins the
    endpoint URL to the Siliconflow API.
    """

    def _invoke(
        self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None
    ) -> str:
        """
        Invoke speech2text model
        :param model: model name
        :param credentials: model credentials (``endpoint_url`` is set in place)
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        self._add_custom_parameters(credentials)
        # Forward ``user`` as well so the override does not silently drop it.
        return super()._invoke(model, credentials, file, user=user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials against the Siliconflow endpoint
        :param model: model name
        :param credentials: model credentials (``endpoint_url`` is set in place)
        :return:
        """
        self._add_custom_parameters(credentials)
        return super().validate_credentials(model, credentials)

    @classmethod
    def _add_custom_parameters(cls, credentials: dict) -> None:
        """
        Set the Siliconflow endpoint URL on the given credentials dict
        :param credentials: model credentials, mutated in place
        :return:
        """
        credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import (
OAICompatSpeech2TextModel,
)


def test_validate_credentials():
    """Invalid key must be rejected; the key from OPENAI_API_KEY must pass."""
    stt_model = OAICompatSpeech2TextModel()

    rejected_credentials = {
        "api_key": "invalid_key",
        "endpoint_url": "https://api.openai.com/v1/",
    }
    with pytest.raises(CredentialsValidateFailedError):
        stt_model.validate_credentials(model="whisper-1", credentials=rejected_credentials)

    accepted_credentials = {
        "api_key": os.environ.get("OPENAI_API_KEY"),
        "endpoint_url": "https://api.openai.com/v1/",
    }
    stt_model.validate_credentials(model="whisper-1", credentials=accepted_credentials)


def test_invoke_model():
    """Transcribe the shared audio asset and compare with the known transcript."""
    stt_model = OAICompatSpeech2TextModel()

    # The assets directory sits next to the directory holding this test file.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    audio_file_path = os.path.join(parent_dir, "assets", "audio.mp3")

    credentials = {
        "api_key": os.environ.get("OPENAI_API_KEY"),
        "endpoint_url": "https://api.openai.com/v1/",
    }

    # Keep the file open for the whole request; it is streamed as the upload.
    with open(audio_file_path, "rb") as audio_file:
        transcript = stt_model.invoke(
            model="whisper-1",
            credentials=credentials,
            file=audio_file,
            user="abc-123",
        )

    assert isinstance(transcript, str)
    assert transcript == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.siliconflow.speech2text.speech2text import SiliconflowSpeech2TextModel


def test_validate_credentials():
    """Invalid key must be rejected; the key from API_KEY must pass."""
    stt_model = SiliconflowSpeech2TextModel()

    rejected_credentials = {"api_key": "invalid_key"}
    with pytest.raises(CredentialsValidateFailedError):
        stt_model.validate_credentials(
            model="iic/SenseVoiceSmall",
            credentials=rejected_credentials,
        )

    accepted_credentials = {"api_key": os.environ.get("API_KEY")}
    stt_model.validate_credentials(
        model="iic/SenseVoiceSmall",
        credentials=accepted_credentials,
    )


def test_invoke_model():
    """Transcribe the shared audio asset and compare with the known transcript."""
    stt_model = SiliconflowSpeech2TextModel()

    # The assets directory sits next to the directory holding this test file.
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    audio_file_path = os.path.join(parent_dir, "assets", "audio.mp3")

    credentials = {"api_key": os.environ.get("API_KEY")}

    # Keep the file open for the whole request; it is streamed as the upload.
    with open(audio_file_path, "rb") as audio_file:
        transcript = stt_model.invoke(
            model="iic/SenseVoiceSmall",
            credentials=credentials,
            file=audio_file,
        )

    assert isinstance(transcript, str)
    assert transcript == '1,2,3,4,5,6,7,8,9,10.'

0 comments on commit a12ddc4

Please sign in to comment.