Skip to content

Commit

Permalink
refactor: optimize the calculation of rerank threshold and the logic …
Browse files Browse the repository at this point in the history
…for forbidden characters in model_uid
  • Loading branch information
hwzhuhao committed Sep 29, 2024
1 parent 74f58f2 commit 3a190ac
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 11 deletions.
3 changes: 2 additions & 1 deletion api/core/model_runtime/model_providers/xinference/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from core.model_runtime.model_providers.xinference.xinference_helper import (
XinferenceHelper,
XinferenceModelExtraParameter,
validate_model_uid,
)
from core.model_runtime.utils import helper

Expand Down Expand Up @@ -114,7 +115,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
}
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

extra_param = XinferenceHelper.get_xinference_extra_parameter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid


class XinferenceRerankModel(RerankModel):
Expand Down Expand Up @@ -77,10 +78,7 @@ def _invoke(
)

# score threshold check
if score_threshold is not None:
if result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)
else:
if score_threshold is None or result["relevance_score"] >= score_threshold:
rerank_documents.append(rerank_document)

return RerankResult(model=model, docs=rerank_documents)
Expand All @@ -94,7 +92,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

credentials["server_url"] = credentials["server_url"].removesuffix("/")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid


class XinferenceSpeech2TextModel(Speech2TextModel):
Expand Down Expand Up @@ -42,7 +43,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

credentials["server_url"] = credentials["server_url"].removesuffix("/")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid


class XinferenceTextEmbeddingModel(TextEmbeddingModel):
Expand Down Expand Up @@ -110,7 +110,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

server_url = credentials["server_url"]
Expand Down
4 changes: 2 additions & 2 deletions api/core/model_runtime/model_providers/xinference/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper
from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid


class XinferenceText2SpeechModel(TTSModel):
Expand Down Expand Up @@ -70,7 +70,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
:return:
"""
try:
if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]:
if not validate_model_uid(credentials):
raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #")

credentials["server_url"] = credentials["server_url"].removesuffix("/")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,16 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: st
context_length=context_length,
model_family=model_family,
)


def validate_model_uid(credentials: dict) -> bool:
"""
Validate the model_uid within the credentials dictionary to ensure it does not
contain forbidden characters ("/", "?", "#").
param credentials: model credentials
:return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False.
"""
forbidden_characters = ["/", "?", "#"]
model_uid = credentials.get("model_uid", "")
return not any(char in forbidden_characters for char in model_uid)

0 comments on commit 3a190ac

Please sign in to comment.