-
Notifications
You must be signed in to change notification settings - Fork 8.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add support of speech2text function for OpenAI-API-compatible a…
…nd Siliconflow (#7197)
- Loading branch information
Showing
10 changed files
with
231 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
63 changes: 63 additions & 0 deletions
63
api/core/model_runtime/model_providers/openai_api_compatible/speech2text/speech2text.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from typing import IO, Optional | ||
from urllib.parse import urljoin | ||
|
||
import requests | ||
|
||
from core.model_runtime.errors.invoke import InvokeBadRequestError | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel | ||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat | ||
|
||
|
||
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): | ||
""" | ||
Model class for OpenAI Compatible Speech to text model. | ||
""" | ||
|
||
def _invoke( | ||
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None | ||
) -> str: | ||
""" | ||
Invoke speech2text model | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param file: audio file | ||
:param user: unique user id | ||
:return: text for given audio file | ||
""" | ||
headers = {} | ||
|
||
api_key = credentials.get("api_key") | ||
if api_key: | ||
headers["Authorization"] = f"Bearer {api_key}" | ||
|
||
endpoint_url = credentials.get("endpoint_url") | ||
if not endpoint_url.endswith("/"): | ||
endpoint_url += "/" | ||
endpoint_url = urljoin(endpoint_url, "audio/transcriptions") | ||
|
||
payload = {"model": model} | ||
files = [("file", file)] | ||
response = requests.post(endpoint_url, headers=headers, data=payload, files=files) | ||
|
||
if response.status_code != 200: | ||
raise InvokeBadRequestError(response.text) | ||
response_data = response.json() | ||
return response_data["text"] | ||
|
||
def validate_credentials(self, model: str, credentials: dict) -> None: | ||
""" | ||
Validate model credentials | ||
:param model: model name | ||
:param credentials: model credentials | ||
:return: | ||
""" | ||
try: | ||
audio_file_path = self._get_demo_file_path() | ||
|
||
with open(audio_file_path, "rb") as audio_file: | ||
self._invoke(model, credentials, audio_file) | ||
except Exception as ex: | ||
raise CredentialsValidateFailedError(str(ex)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
5 changes: 5 additions & 0 deletions
5
api/core/model_runtime/model_providers/siliconflow/speech2text/sense-voice-small.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
model: iic/SenseVoiceSmall | ||
model_type: speech2text | ||
model_properties: | ||
file_upload_limit: 1 | ||
supported_file_extensions: mp3,wav |
32 changes: 32 additions & 0 deletions
32
api/core/model_runtime/model_providers/siliconflow/speech2text/speech2text.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from typing import IO, Optional | ||
|
||
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel | ||
|
||
|
||
class SiliconflowSpeech2TextModel(OAICompatSpeech2TextModel): | ||
""" | ||
Model class for Siliconflow Speech to text model. | ||
""" | ||
|
||
def _invoke( | ||
self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None | ||
) -> str: | ||
""" | ||
Invoke speech2text model | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param file: audio file | ||
:param user: unique user id | ||
:return: text for given audio file | ||
""" | ||
self._add_custom_parameters(credentials) | ||
return super()._invoke(model, credentials, file) | ||
|
||
def validate_credentials(self, model: str, credentials: dict) -> None: | ||
self._add_custom_parameters(credentials) | ||
return super().validate_credentials(model, credentials) | ||
|
||
@classmethod | ||
def _add_custom_parameters(cls, credentials: dict) -> None: | ||
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" |
59 changes: 59 additions & 0 deletions
59
api/tests/integration_tests/model_runtime/openai_api_compatible/test_speech2text.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import ( | ||
OAICompatSpeech2TextModel, | ||
) | ||
|
||
|
||
def test_validate_credentials(): | ||
model = OAICompatSpeech2TextModel() | ||
|
||
with pytest.raises(CredentialsValidateFailedError): | ||
model.validate_credentials( | ||
model="whisper-1", | ||
credentials={ | ||
"api_key": "invalid_key", | ||
"endpoint_url": "https://api.openai.com/v1/" | ||
}, | ||
) | ||
|
||
model.validate_credentials( | ||
model="whisper-1", | ||
credentials={ | ||
"api_key": os.environ.get("OPENAI_API_KEY"), | ||
"endpoint_url": "https://api.openai.com/v1/" | ||
}, | ||
) | ||
|
||
|
||
def test_invoke_model(): | ||
model = OAICompatSpeech2TextModel() | ||
|
||
# Get the directory of the current file | ||
current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
# Get assets directory | ||
assets_dir = os.path.join(os.path.dirname(current_dir), "assets") | ||
|
||
# Construct the path to the audio file | ||
audio_file_path = os.path.join(assets_dir, "audio.mp3") | ||
|
||
# Open the file and get the file object | ||
with open(audio_file_path, "rb") as audio_file: | ||
file = audio_file | ||
|
||
result = model.invoke( | ||
model="whisper-1", | ||
credentials={ | ||
"api_key": os.environ.get("OPENAI_API_KEY"), | ||
"endpoint_url": "https://api.openai.com/v1/" | ||
}, | ||
file=file, | ||
user="abc-123", | ||
) | ||
|
||
assert isinstance(result, str) | ||
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' |
53 changes: 53 additions & 0 deletions
53
api/tests/integration_tests/model_runtime/siliconflow/test_speech2text.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.siliconflow.speech2text.speech2text import SiliconflowSpeech2TextModel | ||
|
||
|
||
def test_validate_credentials(): | ||
model = SiliconflowSpeech2TextModel() | ||
|
||
with pytest.raises(CredentialsValidateFailedError): | ||
model.validate_credentials( | ||
model="iic/SenseVoiceSmall", | ||
credentials={ | ||
"api_key": "invalid_key" | ||
}, | ||
) | ||
|
||
model.validate_credentials( | ||
model="iic/SenseVoiceSmall", | ||
credentials={ | ||
"api_key": os.environ.get("API_KEY") | ||
}, | ||
) | ||
|
||
|
||
def test_invoke_model(): | ||
model = SiliconflowSpeech2TextModel() | ||
|
||
# Get the directory of the current file | ||
current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
# Get assets directory | ||
assets_dir = os.path.join(os.path.dirname(current_dir), "assets") | ||
|
||
# Construct the path to the audio file | ||
audio_file_path = os.path.join(assets_dir, "audio.mp3") | ||
|
||
# Open the file and get the file object | ||
with open(audio_file_path, "rb") as audio_file: | ||
file = audio_file | ||
|
||
result = model.invoke( | ||
model="iic/SenseVoiceSmall", | ||
credentials={ | ||
"api_key": os.environ.get("API_KEY") | ||
}, | ||
file=file | ||
) | ||
|
||
assert isinstance(result, str) | ||
assert result == '1,2,3,4,5,6,7,8,9,10.' |