Skip to content

Commit

Permalink
ran lint
Browse files Browse the repository at this point in the history
  • Loading branch information
snova-rodrigom committed Dec 20, 2024
1 parent 2d6efd2 commit ecc4933
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
7 changes: 2 additions & 5 deletions api/core/model_runtime/model_providers/sambanova/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

import json


class SambanovaLargeLanguageModel(LargeLanguageModel):
def _invoke(
Expand Down Expand Up @@ -102,7 +100,7 @@ def _chat_generate(
stream=stream,
**model_parameters,
)

if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
Expand Down Expand Up @@ -352,7 +350,6 @@ def _handle_chat_generate_stream_response(

yield final_chunk


def _to_credential_kwargs(self, credentials: dict) -> dict:
"""
Transform credentials to kwargs for model instance
Expand All @@ -375,7 +372,7 @@ def _convert_message_to_dict(self, message: PromptMessage) -> dict[str, Any]:
messages_dict: role / content dict
"""
message_dict: dict[str, Any] = {}

if isinstance(message, SystemPromptMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, UserPromptMessage):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_validate_credentials():
model = SambanovaLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(model='Meta-Llama-3.1-8B-Instruct', credentials={"sambanova_api_key": "invalid_key"})
model.validate_credentials(model="Meta-Llama-3.1-8B-Instruct", credentials={"sambanova_api_key": "invalid_key"})

model.validate_credentials(
model="Meta-Llama-3.1-8B-Instruct", credentials={"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY")}
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_get_num_tokens():
model = SambanovaLargeLanguageModel()
Expand All @@ -88,7 +88,7 @@ def test_get_num_tokens():
),
UserPromptMessage(content="Hello World!"),
],
model_parameters={"temperature": 0.0}
model_parameters={"temperature": 0.0},
)

assert num_tokens == 25
assert num_tokens == 25
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,4 @@ def test_validate_provider_credentials():
provider.validate_provider_credentials(credentials={"sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY")})


test_validate_provider_credentials()
test_validate_provider_credentials()

0 comments on commit ecc4933

Please sign in to comment.