From aee698854abb75586fdbad20ab0543313267b3d9 Mon Sep 17 00:00:00 2001 From: ethanhsu Date: Sat, 17 Aug 2024 12:00:15 +0800 Subject: [PATCH] feat: support xinference's auth system --- .../model_providers/xinference/llm/llm.py | 10 +++++++--- .../model_providers/xinference/rerank/rerank.py | 3 ++- .../xinference/speech2text/speech2text.py | 3 ++- .../xinference/text_embedding/text_embedding.py | 12 ++++++++++-- .../model_providers/xinference/tts/tts.py | 9 +++++++-- .../model_providers/xinference/xinference_helper.py | 9 +++++---- 6 files changed, 33 insertions(+), 13 deletions(-) diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 988bb0ce4432df..4760e8f1185e8d 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -85,7 +85,8 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes tools=tools, stop=stop, stream=stream, user=user, extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) ) @@ -106,7 +107,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key') ) if 'completion_type' not in credentials: if 'chat' in extra_param.model_ability: @@ -396,7 +398,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode else: extra_args = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key') ) if 'chat' in extra_args.model_ability: @@ -464,6 +467,7 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM xinference_client = Client( base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_model = xinference_client.get_model(credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 4e7543fd996fd7..d809537479f40d 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -108,7 +108,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py index 9ee36213176ef7..62b77f22e59c87 100644 --- a/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/xinference/speech2text/speech2text.py @@ -52,7 +52,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: # initialize client client = Client( - base_url=credentials['server_url'] + base_url=credentials['server_url'], + api_key=credentials.get('api_key'), ) xinference_client = client.get_model(model_uid=credentials['model_uid']) diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 11f1e29cb39f81..3a8d704c25838c 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -110,14 +110,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None: server_url = credentials['server_url'] model_uid = credentials['model_uid'] - extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + api_key = credentials.get('api_key') + extra_args = XinferenceHelper.get_xinference_extra_parameter( + server_url=server_url, + model_uid=model_uid, + api_key=api_key, + ) if extra_args.max_tokens: credentials['max_tokens'] = extra_args.max_tokens if server_url.endswith('/'): server_url = server_url[:-1] - client = Client(base_url=server_url) + client = Client( + base_url=server_url, + api_key=api_key, + ) try: handle = client.get_model(model_uid=model_uid) diff --git a/api/core/model_runtime/model_providers/xinference/tts/tts.py b/api/core/model_runtime/model_providers/xinference/tts/tts.py index a564a021b19615..bfa752df8cdb31 100644 --- a/api/core/model_runtime/model_providers/xinference/tts/tts.py +++ b/api/core/model_runtime/model_providers/xinference/tts/tts.py @@ -81,7 +81,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], - model_uid=credentials['model_uid'] + model_uid=credentials['model_uid'], + api_key=credentials.get('api_key'), ) if 'text-to-audio' not in extra_param.model_ability: @@ -203,7 +204,11 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str credentials['server_url'] = credentials['server_url'][:-1] try: - handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) + api_key = credentials.get('api_key') + auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} + handle = RESTfulAudioModelHandle( + credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers + ) model_support_voice = [x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials)] diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 7db483a485ee1c..75161ad376c419 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -35,13 +35,13 @@ def __init__(self, model_format: str, model_handle_type: str, model_ability: lis class XinferenceHelper: @staticmethod - def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: XinferenceHelper._clean_cache() with cache_lock: if model_uid not in cache: cache[model_uid] = { 'expires': time() + 300, - 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) + 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key) } return cache[model_uid]['value'] @@ -56,7 +56,7 @@ def _clean_cache() -> None: pass @staticmethod - def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: + def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter: """ get xinference model extra parameter like model_format and model_handle_type """ @@ -70,9 +70,10 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen session = Session() session.mount('http://', HTTPAdapter(max_retries=3)) session.mount('https://', HTTPAdapter(max_retries=3)) + headers = {'Authorization': f'Bearer {api_key}'} if api_key else {} try: - response = session.get(url, timeout=10) + response = session.get(url, headers=headers, timeout=10) except (MissingSchema, ConnectionError, Timeout) as e: raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') if response.status_code != 200: