Skip to content

Commit

Permalink
feat: support xinference's auth system
Browse files Browse the repository at this point in the history
  • Loading branch information
realethanhsu committed Aug 17, 2024
1 parent 4d4af00 commit aee6988
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 13 deletions.
10 changes: 7 additions & 3 deletions api/core/model_runtime/model_providers/xinference/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMes
tools=tools, stop=stop, stream=stream, user=user,
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
)
)

Expand All @@ -106,7 +107,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:

extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
)
if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability:
Expand Down Expand Up @@ -396,7 +398,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
else:
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
)

if 'chat' in extra_args.model_ability:
Expand Down Expand Up @@ -464,6 +467,7 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM

xinference_client = Client(
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)

xinference_model = xinference_client.get_model(credentials['model_uid'])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:

# initialize client
client = Client(
base_url=credentials['server_url']
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)

xinference_client = client.get_model(model_uid=credentials['model_uid'])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:

# initialize client
client = Client(
base_url=credentials['server_url']
base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
)

xinference_client = client.get_model(model_uid=credentials['model_uid'])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,22 @@ def validate_credentials(self, model: str, credentials: dict) -> None:

server_url = credentials['server_url']
model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid)
api_key = credentials.get('api_key')
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=server_url,
model_uid=model_uid,
api_key=api_key,
)

if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens
if server_url.endswith('/'):
server_url = server_url[:-1]

client = Client(base_url=server_url)
client = Client(
base_url=server_url,
api_key=api_key,
)

try:
handle = client.get_model(model_uid=model_uid)
Expand Down
9 changes: 7 additions & 2 deletions api/core/model_runtime/model_providers/xinference/tts/tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:

extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'],
model_uid=credentials['model_uid']
model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
)

if 'text-to-audio' not in extra_param.model_ability:
Expand Down Expand Up @@ -203,7 +204,11 @@ def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str
credentials['server_url'] = credentials['server_url'][:-1]

try:
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={})
api_key = credentials.get('api_key')
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
handle = RESTfulAudioModelHandle(
credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
)

model_support_voice = [x.get("value") for x in
self.get_tts_model_voices(model=model, credentials=credentials)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ def __init__(self, model_format: str, model_handle_type: str, model_ability: lis

class XinferenceHelper:
@staticmethod
def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
XinferenceHelper._clean_cache()
with cache_lock:
if model_uid not in cache:
cache[model_uid] = {
'expires': time() + 300,
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid)
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
}
return cache[model_uid]['value']

Expand All @@ -56,7 +56,7 @@ def _clean_cache() -> None:
pass

@staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
"""
get xinference model extra parameter like model_format and model_handle_type
"""
Expand All @@ -70,9 +70,10 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen
session = Session()
session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3))
headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}

try:
response = session.get(url, timeout=10)
response = session.get(url, headers=headers, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200:
Expand Down

0 comments on commit aee6988

Please sign in to comment.