
linting
logan-markewich committed Dec 19, 2024
1 parent 59d96f4 commit 1a5588a
Showing 1 changed file with 30 additions and 5 deletions.
@@ -36,6 +36,7 @@

def _convert_message_to_dict(message: ChatMessage) -> Dict[str, Any]:
"""Converts a ChatMessage to a dictionary with Role / content.
Args:
message: ChatMessage
Returns:
@@ -50,6 +51,7 @@ def _convert_message_to_dict(message: ChatMessage) -> Dict[str, Any]:

def _create_message_dicts(messages: Sequence[ChatMessage]) -> List[Dict[str, Any]]:
"""Converts a list of ChatMessages to a list of dictionaries with Role / content.
Args:
messages: list of ChatMessages
Returns:
@@ -59,8 +61,8 @@ def _create_message_dicts(messages: Sequence[ChatMessage]) -> List[Dict[str, Any


class SambaNovaCloud(LLM):
"""
SambaNova Cloud models.
"""SambaNova Cloud models.
Setup:
To use, you should have the environment variables:
`SAMBANOVA_URL` set with your SambaNova Cloud URL.
@@ -227,6 +229,7 @@ def _handle_request(
) -> Dict[str, Any]:
"""
Performs a post request to the LLM API.
Args:
messages_dicts: List of role / content dicts to use as input.
stop: list of stop tokens
@@ -271,6 +274,7 @@ async def _handle_request_async(
) -> Dict[str, Any]:
"""
Performs an async post request to the LLM API.
Args:
messages_dicts: List of role / content dicts to use as input.
stop: list of stop tokens
@@ -314,6 +318,7 @@ def _handle_streaming_request(
) -> Iterator[Dict]:
"""
Performs a streaming post request to the LLM API.
Args:
messages_dicts: List of role / content dicts to use as input.
stop: list of stop tokens
@@ -396,6 +401,7 @@ async def _handle_streaming_request_async(
) -> AsyncIterator[Dict]:
"""
Performs an async streaming post request to the LLM API.
Args:
messages_dicts: List of role / content dicts to use as input.
stop: list of stop tokens
@@ -465,6 +471,7 @@ def chat(
) -> ChatResponse:
"""
Calls the chat implementation of the SambaNovaCloud model.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
@@ -473,6 +480,7 @@ def chat(
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
Returns:
ChatResponse with model generation
"""
@@ -502,6 +510,7 @@ def stream_chat(
) -> ChatResponseGen:
"""
Streams the chat output of the SambaNovaCloud model.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
@@ -510,6 +519,7 @@ def stream_chat(
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
Yields:
ChatResponseGen with model partial generation
"""
@@ -572,6 +582,7 @@ async def achat(
) -> ChatResponse:
"""
Calls the async chat implementation of the SambaNovaCloud model.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
@@ -580,6 +591,7 @@ async def achat(
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
Returns:
ChatResponse with async model generation
"""
@@ -627,8 +639,8 @@ def astream_complete(


class SambaStudio(LLM):
"""
SambaStudio model.
"""SambaStudio model.
Setup:
To use, you should have the environment variables:
``SAMBASTUDIO_URL`` set with your SambaStudio deployed endpoint URL.
@@ -852,6 +864,7 @@ def _messages_to_string(self, messages: Sequence[ChatMessage]) -> str:
"""Convert a sequence of ChatMessages to:
- dumped json string with Role / content dict structure when process_prompt is true,
- string with special tokens if process_prompt is false for generic V1 and V2 endpoints.
Args:
messages: sequence of ChatMessages
Returns:
@@ -884,6 +897,7 @@ def _messages_to_string(self, messages: Sequence[ChatMessage]) -> str:

def _get_sambastudio_urls(self, url: str) -> Tuple[str, str]:
"""Get streaming and non streaming URLs from the given URL.
Args:
url: string with sambastudio base or streaming endpoint url
Returns:
@@ -912,6 +926,7 @@ def _handle_request(
streaming: Optional[bool] = False,
) -> Response:
"""Performs a post request to the LLM API.
Args:
messages_dicts: List of role / content dicts to use as input.
stop: list of stop tokens
@@ -1018,6 +1033,7 @@ async def _handle_request_async(
streaming: Optional[bool] = False,
) -> Response:
"""Performs an async post request to the LLM API.
Args:
messages_dicts: List of role / content dicts to use as input.
stop: list of stop tokens
@@ -1127,6 +1143,7 @@ async def _handle_request_async(

def _process_response(self, response: Response) -> ChatMessage:
"""Process a non streaming response from the api.
Args:
response: A request Response object
Returns:
@@ -1176,6 +1193,7 @@ def _process_response(self, response: Response) -> ChatMessage:

def _process_stream_response(self, response: Response) -> Iterator[ChatMessage]:
"""Process a streaming response from the api.
Args:
response: An iterable request Response object
Yields:
@@ -1381,6 +1399,7 @@ async def _process_response_async(
self, response_dict: Dict[str, Any]
) -> ChatMessage:
"""Process a non streaming response from the api.
Args:
response: A request Response object
Returns:
@@ -1427,6 +1446,7 @@ def chat(
**kwargs: Any,
) -> ChatResponse:
"""Calls the chat implementation of the SambaStudio model.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
@@ -1435,6 +1455,7 @@ def chat(
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
Returns:
ChatResponse with model generation
"""
@@ -1457,6 +1478,7 @@ def stream_chat(
**kwargs: Any,
) -> ChatResponseGen:
"""Stream the output of the SambaStudio model.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
@@ -1465,6 +1487,7 @@ def stream_chat(
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
Yields:
chunk: ChatResponseGen with model partial generation
"""
@@ -1495,6 +1518,7 @@ async def achat(
**kwargs: Any,
) -> ChatResponse:
"""Calls the chat implementation of the SambaStudio model.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
@@ -1503,6 +1527,7 @@ async def achat(
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
Returns:
ChatResponse with model generation
"""
@@ -1536,4 +1561,4 @@ def astream_complete(
) -> CompletionResponseAsyncGen:
raise NotImplementedError(
"SambaStudio does not currently support async streaming."
)
)
