
Commit
feat:use xinference tts stream mode (langgenius#8616)
leslie2046 authored and JunXu01 committed Nov 9, 2024
1 parent 2d95cb2 commit fd75683
Showing 5 changed files with 13 additions and 17 deletions.
3 changes: 1 addition & 2 deletions api/core/model_runtime/model_providers/xinference/llm/llm.py
@@ -19,7 +19,6 @@
 from openai.types.completion import Completion
 from xinference_client.client.restful.restful_client import (
     Client,
-    RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulGenerateModelHandle,
 )
@@ -491,7 +490,7 @@ def _generate(
         if tools and len(tools) > 0:
             generate_config["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools]
         vision = credentials.get("support_vision", False)
-        if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
+        if isinstance(xinference_model, RESTfulChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials["model_uid"],
                 messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
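
With RESTfulChatglmCppChatModelHandle gone, the chat path is gated on RESTfulChatModelHandle alone. A minimal sketch of that dispatch, assuming a locally running Xinference server and a hypothetical model UID (neither is part of this commit):

from xinference_client.client.restful.restful_client import (
    Client,
    RESTfulChatModelHandle,
    RESTfulGenerateModelHandle,
)

client = Client("http://127.0.0.1:9997")   # assumed local Xinference endpoint
handle = client.get_model("my-model-uid")  # hypothetical model UID

if isinstance(handle, RESTfulChatModelHandle):
    pass  # chat-capable model: routed through the OpenAI-compatible chat completions API
elif isinstance(handle, RESTfulGenerateModelHandle):
    pass  # plain completion model: routed through the completion endpoint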
12 changes: 6 additions & 6 deletions api/core/model_runtime/model_providers/xinference/tts/tts.py
@@ -208,21 +208,21 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str
                 executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences)))
                 futures = [
                     executor.submit(
-                        handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=False
+                        handle.speech, input=sentences[i], voice=voice, response_format="mp3", speed=1.0, stream=True
                     )
                     for i in range(len(sentences))
                 ]

                 for future in futures:
                     response = future.result()
-                    for i in range(0, len(response), 1024):
-                        yield response[i : i + 1024]
+                    for chunk in response:
+                        yield chunk
             else:
                 response = handle.speech(
-                    input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=False
+                    input=content_text.strip(), voice=voice, response_format="mp3", speed=1.0, stream=True
                 )

-                for i in range(0, len(response), 1024):
-                    yield response[i : i + 1024]
+                for chunk in response:
+                    yield chunk
         except Exception as ex:
             raise InvokeBadRequestError(str(ex))
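
The substance of the tts.py change is how the audio payload is consumed. A side-by-side sketch of the two shapes, assuming (as the new code does) that handle.speech(..., stream=True) returns an iterable of byte chunks while stream=False returned a single bytes payload:

from collections.abc import Iterable, Iterator

def old_style(response: bytes) -> Iterator[bytes]:
    # stream=False: the whole mp3 arrives at once and is re-sliced into 1 KiB pieces
    for i in range(0, len(response), 1024):
        yield response[i : i + 1024]

def new_style(response: Iterable[bytes]) -> Iterator[bytes]:
    # stream=True: each chunk is forwarded as soon as the server produces it
    yield from response

The streaming variant lets the caller start consuming audio before synthesis of the whole text has finished, which is what the stream mode in the commit title refers to.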
8 changes: 4 additions & 4 deletions api/poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion api/pyproject.toml
@@ -203,7 +203,7 @@ transformers = "~4.35.0"
 unstructured = { version = "~0.10.27", extras = ["docx", "epub", "md", "msg", "ppt", "pptx"] }
 websocket-client = "~1.7.0"
 werkzeug = "~3.0.1"
-xinference-client = "0.13.3"
+xinference-client = "0.15.2"
 yarl = "~1.9.4"
 zhipuai = "1.0.7"
 # Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
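
The dependency bump ties the rest of the commit together: every import of RESTfulChatglmCppChatModelHandle is dropped in the same change set, which suggests the class is no longer exported by xinference-client 0.15.2. A defensive probe along these lines (an illustration, not code from the commit) shows the failure an un-updated import would hit:

try:
    # expected to fail once xinference-client no longer ships the ChatGLM cpp handle
    from xinference_client.client.restful.restful_client import RESTfulChatglmCppChatModelHandle
except ImportError:
    RESTfulChatglmCppChatModelHandle = None  # handle type not available in this client version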
5 changes: 1 addition & 4 deletions (Xinference client test mock)
@@ -9,7 +9,6 @@
 from requests.sessions import Session
 from xinference_client.client.restful.restful_client import (
     Client,
-    RESTfulChatglmCppChatModelHandle,
     RESTfulChatModelHandle,
     RESTfulEmbeddingModelHandle,
     RESTfulGenerateModelHandle,
@@ -19,9 +18,7 @@


 class MockXinferenceClass:
-    def get_chat_model(
-        self: Client, model_uid: str
-    ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
+    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
         if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
             raise RuntimeError("404 Not Found")

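
The mock keeps the base_url sanity check unchanged; only the return annotation is narrowed. A tiny illustration (not from the commit) of what that regex accepts and rejects:

import re

_URL_RE = r"https?:\/\/[^\s\/$.?#].[^\s]*$"

assert re.match(_URL_RE, "http://127.0.0.1:9997")        # plausible local server: accepted
assert re.match(_URL_RE, "https://xinference.example")   # hypothetical host: accepted
assert re.match(_URL_RE, "not-a-url") is None            # malformed: the mock raises "404 Not Found"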

0 comments on commit fd75683
