diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 286640079b02a9..0c9d08679ad046 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -59,6 +59,7 @@ from core.model_runtime.model_providers.xinference.xinference_helper import ( XinferenceHelper, XinferenceModelExtraParameter, + validate_model_uid, ) from core.model_runtime.utils import helper @@ -114,7 +115,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: } """ try: - if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") extra_param = XinferenceHelper.get_xinference_extra_parameter( diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 8f18bc42d2339d..6368cd76dc97bf 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -15,6 +15,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid class XinferenceRerankModel(RerankModel): @@ -77,10 +78,7 @@ def _invoke( ) # score threshold check - if score_threshold is not None: - if result["relevance_score"] >= score_threshold: - rerank_documents.append(rerank_document) - else: + if score_threshold is None or result["relevance_score"] >= score_threshold: rerank_documents.append(rerank_document) return RerankResult(model=model, docs=rerank_documents) @@ -94,7 +92,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index a6c5b8a0a571e9..c5ad38391185a9 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -14,6 +14,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.xinference.xinference_helper import validate_model_uid class XinferenceSpeech2TextModel(Speech2TextModel): @@ -42,7 +43,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 16272391320d55..ddc21b365c1b78 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -17,7 +17,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid class XinferenceTextEmbeddingModel(TextEmbeddingModel): @@ -110,7 +110,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") server_url = credentials["server_url"] diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index 81dbe397d2f10c..6290e8551d8020 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -15,7 +15,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel -from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper, validate_model_uid class XinferenceText2SpeechModel(TTSModel): @@ -70,7 +70,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - if "/" in credentials["model_uid"] or "?" in credentials["model_uid"] or "#" in credentials["model_uid"]: + if not validate_model_uid(credentials): raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") credentials["server_url"] = credentials["server_url"].removesuffix("/") diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 619ee1492a9272..baa3ccbe8adbc0 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -132,3 +132,16 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: st context_length=context_length, model_family=model_family, ) + + +def validate_model_uid(credentials: dict) -> bool: + """ + Validate the model_uid within the credentials dictionary to ensure it does not + contain forbidden characters ("/", "?", "#"). + + param credentials: model credentials + :return: True if the model_uid does not contain forbidden characters ("/", "?", "#"), else False. + """ + forbidden_characters = ["/", "?", "#"] + model_uid = credentials.get("model_uid", "") + return not any(char in forbidden_characters for char in model_uid)