forked from langgenius/dify
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: support fish audio TTS (langgenius#7982)
- Loading branch information
1 parent
98ab10d
commit a463df1
Showing
12 changed files
with
433 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
1 change: 1 addition & 0 deletions
1
api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_l_en.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions
1
api/core/model_runtime/model_providers/fishaudio/_assets/fishaudio_s_en.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions
28
api/core/model_runtime/model_providers/fishaudio/fishaudio.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import logging | ||
|
||
from core.model_runtime.entities.model_entities import ModelType | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class FishAudioProvider(ModelProvider): | ||
def validate_provider_credentials(self, credentials: dict) -> None: | ||
""" | ||
Validate provider credentials | ||
For debugging purposes, this method now always passes validation. | ||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`. | ||
""" | ||
try: | ||
model_instance = self.get_model_instance(ModelType.TTS) | ||
model_instance.validate_credentials( | ||
credentials=credentials | ||
) | ||
except CredentialsValidateFailedError as ex: | ||
raise ex | ||
except Exception as ex: | ||
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') | ||
raise ex |
76 changes: 76 additions & 0 deletions
76
api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
provider: fishaudio | ||
label: | ||
en_US: Fish Audio | ||
description: | ||
en_US: Models provided by Fish Audio, currently only support TTS. | ||
zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。 | ||
icon_small: | ||
en_US: fishaudio_s_en.svg | ||
icon_large: | ||
en_US: fishaudio_l_en.svg | ||
background: "#E5E7EB" | ||
help: | ||
title: | ||
en_US: Get your API key from Fish Audio | ||
zh_Hans: 从 Fish Audio 获取你的 API Key | ||
url: | ||
en_US: https://fish.audio/go-api/ | ||
supported_model_types: | ||
- tts | ||
configurate_methods: | ||
- predefined-model | ||
provider_credential_schema: | ||
credential_form_schemas: | ||
- variable: api_key | ||
label: | ||
en_US: API Key | ||
type: secret-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入您的 API Key | ||
en_US: Enter your API Key | ||
- variable: api_base | ||
label: | ||
en_US: API URL | ||
type: text-input | ||
required: false | ||
default: https://api.fish.audio | ||
placeholder: | ||
en_US: Enter your API URL | ||
zh_Hans: 在此输入您的 API URL | ||
- variable: use_public_models | ||
label: | ||
en_US: Use Public Models | ||
type: select | ||
required: false | ||
default: "false" | ||
placeholder: | ||
en_US: Toggle to use public models | ||
zh_Hans: 切换以使用公共模型 | ||
options: | ||
- value: "true" | ||
label: | ||
en_US: Allow Public Models | ||
zh_Hans: 使用公共模型 | ||
- value: "false" | ||
label: | ||
en_US: Private Models Only | ||
zh_Hans: 仅使用私有模型 | ||
- variable: latency | ||
label: | ||
en_US: Latency | ||
type: select | ||
required: false | ||
default: "normal" | ||
placeholder: | ||
en_US: Toggle to choice latency | ||
zh_Hans: 切换以调整延迟 | ||
options: | ||
- value: "balanced" | ||
label: | ||
en_US: Low (may affect quality) | ||
zh_Hans: 低延迟 (可能降低质量) | ||
- value: "normal" | ||
label: | ||
en_US: Normal | ||
zh_Hans: 标准 |
Empty file.
174 changes: 174 additions & 0 deletions
174
api/core/model_runtime/model_providers/fishaudio/tts/tts.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from typing import Optional | ||
|
||
import httpx | ||
|
||
from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.tts_model import TTSModel | ||
|
||
|
||
class FishAudioText2SpeechModel(TTSModel): | ||
""" | ||
Model class for Fish.audio Text to Speech model. | ||
""" | ||
|
||
def get_tts_model_voices( | ||
self, model: str, credentials: dict, language: Optional[str] = None | ||
) -> list: | ||
api_base = credentials.get("api_base", "https://api.fish.audio") | ||
api_key = credentials.get("api_key") | ||
use_public_models = credentials.get("use_public_models", "false") == "true" | ||
|
||
params = { | ||
"self": str(not use_public_models).lower(), | ||
"page_size": "100", | ||
} | ||
|
||
if language is not None: | ||
if "-" in language: | ||
language = language.split("-")[0] | ||
params["language"] = language | ||
|
||
results = httpx.get( | ||
f"{api_base}/model", | ||
headers={"Authorization": f"Bearer {api_key}"}, | ||
params=params, | ||
) | ||
|
||
results.raise_for_status() | ||
data = results.json() | ||
|
||
return [{"name": i["title"], "value": i["_id"]} for i in data["items"]] | ||
|
||
def _invoke( | ||
self, | ||
model: str, | ||
tenant_id: str, | ||
credentials: dict, | ||
content_text: str, | ||
voice: str, | ||
user: Optional[str] = None, | ||
) -> any: | ||
""" | ||
Invoke text2speech model | ||
:param model: model name | ||
:param tenant_id: user tenant id | ||
:param credentials: model credentials | ||
:param voice: model timbre | ||
:param content_text: text content to be translated | ||
:param user: unique user id | ||
:return: generator yielding audio chunks | ||
""" | ||
|
||
return self._tts_invoke_streaming( | ||
model=model, | ||
credentials=credentials, | ||
content_text=content_text, | ||
voice=voice, | ||
) | ||
|
||
def validate_credentials( | ||
self, credentials: dict, user: Optional[str] = None | ||
) -> None: | ||
""" | ||
Validate credentials for text2speech model | ||
:param credentials: model credentials | ||
:param user: unique user id | ||
""" | ||
|
||
try: | ||
self.get_tts_model_voices( | ||
None, | ||
credentials={ | ||
"api_key": credentials["api_key"], | ||
"api_base": credentials["api_base"], | ||
# Disable public models will trigger a 403 error if user is not logged in | ||
"use_public_models": "false", | ||
}, | ||
) | ||
except Exception as ex: | ||
raise CredentialsValidateFailedError(str(ex)) | ||
|
||
def _tts_invoke_streaming( | ||
self, model: str, credentials: dict, content_text: str, voice: str | ||
) -> any: | ||
""" | ||
Invoke streaming text2speech model | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param content_text: text content to be translated | ||
:param voice: ID of the reference audio (if any) | ||
:return: generator yielding audio chunks | ||
""" | ||
|
||
try: | ||
word_limit = self._get_model_word_limit(model, credentials) | ||
if len(content_text) > word_limit: | ||
sentences = self._split_text_into_sentences( | ||
content_text, max_length=word_limit | ||
) | ||
else: | ||
sentences = [content_text.strip()] | ||
|
||
for i in range(len(sentences)): | ||
yield from self._tts_invoke_streaming_sentence( | ||
credentials=credentials, content_text=sentences[i], voice=voice | ||
) | ||
|
||
except Exception as ex: | ||
raise InvokeBadRequestError(str(ex)) | ||
|
||
def _tts_invoke_streaming_sentence( | ||
self, credentials: dict, content_text: str, voice: Optional[str] = None | ||
) -> any: | ||
""" | ||
Invoke streaming text2speech model | ||
:param credentials: model credentials | ||
:param content_text: text content to be translated | ||
:param voice: ID of the reference audio (if any) | ||
:return: generator yielding audio chunks | ||
""" | ||
api_key = credentials.get("api_key") | ||
api_url = credentials.get("api_base", "https://api.fish.audio") | ||
latency = credentials.get("latency") | ||
|
||
if not api_key: | ||
raise InvokeBadRequestError("API key is required") | ||
|
||
with httpx.stream( | ||
"POST", | ||
api_url + "/v1/tts", | ||
json={ | ||
"text": content_text, | ||
"reference_id": voice, | ||
"latency": latency | ||
}, | ||
headers={ | ||
"Authorization": f"Bearer {api_key}", | ||
}, | ||
timeout=None, | ||
) as response: | ||
if response.status_code != 200: | ||
raise InvokeBadRequestError( | ||
f"Error: {response.status_code} - {response.text}" | ||
) | ||
yield from response.iter_bytes() | ||
|
||
@property | ||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | ||
""" | ||
Map model invoke error to unified error | ||
The key is the error type thrown to the caller | ||
The value is the error type thrown by the model, | ||
which needs to be converted into a unified error type for the caller. | ||
:return: Invoke error mapping | ||
""" | ||
return { | ||
InvokeBadRequestError: [ | ||
httpx.HTTPStatusError, | ||
], | ||
} |
5 changes: 5 additions & 0 deletions
5
api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
model: tts-default | ||
model_type: tts | ||
model_properties: | ||
word_limit: 1000 | ||
audio_type: 'mp3' |
82 changes: 82 additions & 0 deletions
82
api/tests/integration_tests/model_runtime/__mock/fishaudio.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import os | ||
from collections.abc import Callable | ||
from typing import Literal | ||
|
||
import httpx | ||
import pytest | ||
from _pytest.monkeypatch import MonkeyPatch | ||
|
||
|
||
def mock_get(*args, **kwargs): | ||
if kwargs.get("headers", {}).get("Authorization") != "Bearer test": | ||
raise httpx.HTTPStatusError( | ||
"Invalid API key", | ||
request=httpx.Request("GET", ""), | ||
response=httpx.Response(401), | ||
) | ||
|
||
return httpx.Response( | ||
200, | ||
json={ | ||
"items": [ | ||
{"title": "Model 1", "_id": "model1"}, | ||
{"title": "Model 2", "_id": "model2"}, | ||
] | ||
}, | ||
request=httpx.Request("GET", ""), | ||
) | ||
|
||
|
||
def mock_stream(*args, **kwargs): | ||
class MockStreamResponse: | ||
def __init__(self): | ||
self.status_code = 200 | ||
|
||
def __enter__(self): | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_val, exc_tb): | ||
pass | ||
|
||
def iter_bytes(self): | ||
yield b"Mocked audio data" | ||
|
||
return MockStreamResponse() | ||
|
||
|
||
def mock_fishaudio( | ||
monkeypatch: MonkeyPatch, | ||
methods: list[Literal["list-models", "tts"]], | ||
) -> Callable[[], None]: | ||
""" | ||
mock fishaudio module | ||
:param monkeypatch: pytest monkeypatch fixture | ||
:return: unpatch function | ||
""" | ||
|
||
def unpatch() -> None: | ||
monkeypatch.undo() | ||
|
||
if "list-models" in methods: | ||
monkeypatch.setattr(httpx, "get", mock_get) | ||
|
||
if "tts" in methods: | ||
monkeypatch.setattr(httpx, "stream", mock_stream) | ||
|
||
return unpatch | ||
|
||
|
||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" | ||
|
||
|
||
@pytest.fixture | ||
def setup_fishaudio_mock(request, monkeypatch): | ||
methods = request.param if hasattr(request, "param") else [] | ||
if MOCK: | ||
unpatch = mock_fishaudio(monkeypatch, methods=methods) | ||
|
||
yield | ||
|
||
if MOCK: | ||
unpatch() |
Empty file.
Oops, something went wrong.