Add bedrock client #38

Merged: 1 commit on Oct 26, 2023
2 changes: 2 additions & 0 deletions README.md
@@ -33,6 +33,8 @@ EasyLLM implements clients that are **compatible with OpenAI's Completion API**.
* `sagemaker.ChatCompletion` - Chat with LLMs
* `sagemaker.Completion` - Text completion with LLMs
* `sagemaker.Embedding` - Create embeddings with LLMs
* `bedrock` - Amazon Bedrock LLMs


Check out the [Examples](./examples) to get started.

78 changes: 78 additions & 0 deletions docs/clients/bedrock.md
@@ -0,0 +1,78 @@
# Amazon Bedrock

EasyLLM provides a client for interfacing with Amazon Bedrock models.

- `bedrock.ChatCompletion` - a client for interfacing with Bedrock models that are compatible with the OpenAI ChatCompletion API.
- `bedrock.Completion` - a client for interfacing with Bedrock models that are compatible with the OpenAI Completion API.
- `bedrock.Embedding` - a client for interfacing with Bedrock models that are compatible with the OpenAI Embedding API.

## `bedrock.ChatCompletion`

The `bedrock.ChatCompletion` client is used to interface with Bedrock models that are compatible with the OpenAI ChatCompletion API. Check out the [Examples](../examples/bedrock-chat-completion-api).


```python
import os
# set env for prompt builder
os.environ["BEDROCK_PROMPT"] = "anthropic" # vicuna, wizardlm, stablebeluga, open_assistant
os.environ["AWS_REGION"] = "us-east-1" # change to your region
# os.environ["AWS_ACCESS_KEY_ID"] = "XXX" # needed if not using boto3 session
# os.environ["AWS_SECRET_ACCESS_KEY"] = "XXX" # needed if not using boto3 session

from easyllm.clients import bedrock

response = bedrock.ChatCompletion.create(
    model="anthropic.claude-v2",
    messages=[
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    temperature=0.9,
    top_p=0.6,
    max_tokens=1024,
    debug=False,
)
```
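
The call returns an OpenAI-style response object. Assuming it is a plain dict (a minimal sketch, not part of the original docs), the generated text can be read like this:

```python
print(response["choices"][0]["message"]["content"])
```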


Supported parameters are:

* `model` - The model to use for the completion, e.g. `anthropic.claude-v2`. Must be one of the supported Bedrock models.
* `messages` - `List[ChatMessage]` to use for the completion.
* `temperature` - The temperature to use for the completion. Defaults to 0.9.
* `top_p` - The top_p to use for the completion. Defaults to 0.6.
* `top_k` - The top_k to use for the completion. Defaults to 10.
* `n` - The number of completions to generate. Defaults to 1.
* `max_tokens` - The maximum number of tokens to generate. Defaults to 1024.
* `stop` - The stop sequence(s) to use for the completion. Defaults to None.
* `stream` - Whether to stream the completion. Defaults to False. See the streaming sketch after this list.
* `frequency_penalty` - The frequency penalty to use for the completion. Defaults to 1.0.
* `debug` - Whether to enable debug logging. Defaults to False.
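
With `stream=True` the call returns a generator of OpenAI-style chunks. A minimal streaming sketch (assuming the same model and AWS setup as the example above, and that each chunk is a plain dict):

```python
from easyllm.clients import bedrock

response = bedrock.ChatCompletion.create(
    model="anthropic.claude-v2",
    messages=[{"role": "user", "content": "Count from 1 to 5."}],
    stream=True,
)

# the first chunk carries the assistant role, later chunks carry incremental
# text in delta["content"], and a final chunk closes the stream
for chunk in response:
    delta = chunk["choices"][0]["delta"]
    if delta.get("content"):
        print(delta["content"], end="", flush=True)
```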


### Build Prompt

By default the `bedrock` client reads the `BEDROCK_PROMPT` environment variable and maps its value to a prompt builder via the `PROMPT_MAPPING` dictionary. If the variable is not set, the default prompt builder is used.
You can also set the prompt builder manually.

Check out the [Prompt Utils](../prompt_utils) for more details.


Manually setting the prompt builder:

```python
from easyllm.clients import bedrock

bedrock.prompt_builder = "anthropic"

res = bedrock.ChatCompletion.create(...)
```
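
The client's own warning message suggests that `prompt_builder` may also be set to a function that takes a list of messages and returns a string. A hedged sketch of that variant, assigning an existing builder function instead of its mapping name:

```python
from easyllm.clients import bedrock
from easyllm.prompt_utils import build_anthropic_prompt

# assign the builder callable directly instead of the "anthropic" mapping key
bedrock.prompt_builder = build_anthropic_prompt

res = bedrock.ChatCompletion.create(...)
```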

Using the environment variable:

```python
# can happen elsewhere
import os
os.environ["BEDROCK_PROMPT"] = "anthropic"

from easyllm.clients import bedrock
```
35 changes: 29 additions & 6 deletions docs/prompt_utils.md
@@ -4,12 +4,17 @@ The `prompt_utils` module contains functions to assist with converting Message's

Supported prompt formats:

* [Llama 2](#llama-2-chat-builder)
* [Vicuna](#vicuna-chat-builder)
* [Hugging Face ChatML](#hugging-face-chatml-builder)
* [WizardLM](#wizardlm-chat-builder)
* [StableBeluga2](#stablebeluga2-chat-builder)
* [Open Assistant](#open-assistant-chat-builder)
- [Prompt utilities](#prompt-utilities)
- [Set prompt builder for client](#set-prompt-builder-for-client)
- [Llama 2 Chat builder](#llama-2-chat-builder)
- [Vicuna Chat builder](#vicuna-chat-builder)
- [Hugging Face ChatML builder](#hugging-face-chatml-builder)
- [StarChat](#starchat)
- [Falcon](#falcon)
- [WizardLM Chat builder](#wizardlm-chat-builder)
- [StableBeluga2 Chat builder](#stablebeluga2-chat-builder)
- [Open Assistant Chat builder](#open-assistant-chat-builder)
- [Anthropic Claude Chat builder](#anthropic-claude-chat-builder)

Prompt utils are also exporting a mapping dictionary `PROMPT_MAPPING` that maps a model name to a prompt builder function. This can be used to select the correct prompt builder function via an environment variable.
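
For illustration, a hedged sketch of how `PROMPT_MAPPING` can resolve a builder from an environment variable (the `BEDROCK_PROMPT` variable name comes from the Bedrock client added in this PR):

```python
import os

from easyllm.prompt_utils import PROMPT_MAPPING

# e.g. BEDROCK_PROMPT="anthropic" selects build_anthropic_prompt
builder = PROMPT_MAPPING.get(os.environ.get("BEDROCK_PROMPT", ""))
if builder is not None:
    prompt = builder([{"role": "user", "content": "Hello!"}])
```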

@@ -152,3 +157,21 @@ messages=[
prompt = build_open_assistant_prompt(messages)
```

## Anthropic Claude Chat builder

Creates the Anthropic Claude prompt template, using the `\n\nHuman:` and `\n\nAssistant:` turn markers. If a `Message` with an unsupported `role` is passed, an error will be thrown. [Reference](https://docs.anthropic.com/claude/docs/introduction-to-prompt-design)

Example Models:

* [Bedrock](https://aws.amazon.com/bedrock/claude/)

```python
from easyllm.prompt_utils import build_anthropic_prompt

messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain asynchronous programming in the style of the pirate Blackbeard."},
]
prompt = build_anthropic_prompt(messages)
```

2 changes: 1 addition & 1 deletion easyllm/__init__.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2023-present philschmid <[email protected]>
#
# SPDX-License-Identifier: MIT
__version__ = "0.6.0.dev0"
__version__ = "0.6.0"
219 changes: 219 additions & 0 deletions easyllm/clients/bedrock.py
@@ -0,0 +1,219 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional

from nanoid import generate

from easyllm.prompt_utils.base import build_prompt, buildBasePrompt
from easyllm.schema.base import ChatMessage, Usage, dump_object
from easyllm.schema.openai import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    DeltaMessage,
)
from easyllm.utils import setup_logger
from easyllm.utils.aws import get_bedrock_client

logger = setup_logger()

# default parameters
api_type = "bedrock"
api_aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", None)
api_aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
api_aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None)

client = get_bedrock_client(
    aws_access_key_id=api_aws_access_key,
    aws_secret_access_key=api_aws_secret_key,
    aws_session_token=api_aws_session_token,
)


SUPPORTED_MODELS = [
    "anthropic.claude-v2",
]
model_version_mapping = {"anthropic.claude-v2": "bedrock-2023-05-31"}

api_version = os.environ.get("BEDROCK_API_VERSION", None) or "bedrock-2023-05-31"
prompt_builder = os.environ.get("BEDROCK_PROMPT", None)
stop_sequences = []


def stream_chat_request(client, body, model):
    """Utility function for streaming chat requests."""
    id = f"hf-{generate(size=10)}"
    response = client.invoke_model_with_response_stream(
        body=json.dumps(body), modelId=model, accept="application/json", contentType="application/json"
    )
    stream = response.get("body")

    yield dump_object(
        ChatCompletionStreamResponse(
            id=id,
            model=model,
            choices=[ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"))],
        )
    )
    # yield each generated token
    reason = None
    for _idx, event in enumerate(stream):
        chunk = event.get("chunk")
        if chunk:
            chunk_obj = json.loads(chunk.get("bytes").decode())
            text = chunk_obj["completion"]
            yield dump_object(
                ChatCompletionStreamResponse(
                    id=id,
                    model=model,
                    choices=[ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(content=text))],
                )
            )
    yield dump_object(
        ChatCompletionStreamResponse(
            id=id,
            model=model,
            choices=[ChatCompletionResponseStreamChoice(index=0, finish_reason=reason, delta={})],
        )
    )


class ChatCompletion:
    @staticmethod
    def create(
        messages: List[ChatMessage],
        model: Optional[str] = None,
        temperature: float = 0.9,
        top_p: float = 0.6,
        top_k: Optional[int] = 10,
        n: int = 1,
        max_tokens: int = 1024,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        frequency_penalty: Optional[float] = 1.0,
        debug: bool = False,
    ) -> Dict[str, Any]:
        """
        Creates a new chat completion for the provided messages and parameters.

        Args:
            messages (`List[ChatMessage]`): The messages to use for the completion.
            model (`str`, *optional*, defaults to None): The model to use for the completion, e.g.
                `anthropic.claude-v2`. Must be one of the supported Bedrock models.
            temperature (`float`, defaults to 0.9): The temperature to use for the completion.
            top_p (`float`, defaults to 0.6): The top_p to use for the completion.
            top_k (`int`, *optional*, defaults to 10): The top_k to use for the completion.
            n (`int`, defaults to 1): The number of completions to generate.
            max_tokens (`int`, defaults to 1024): The maximum number of tokens to generate.
            stop (`List[str]`, *optional*, defaults to None): The stop sequence(s) to use for the completion.
            stream (`bool`, defaults to False): Whether to stream the completion.
            frequency_penalty (`float`, *optional*, defaults to 1.0): The frequency penalty to use for the completion.
            debug (`bool`, defaults to False): Whether to enable debug logging.

        Tip: Prompt builder
            Make sure to always use a prompt builder for your model.
        """
        if debug:
            logger.setLevel(logging.DEBUG)

        # validate that the model is supported
        if model not in SUPPORTED_MODELS:
            raise ValueError(f"Model {model} is not supported. Supported models are: {SUPPORTED_MODELS}")

        request = ChatCompletionRequest(
            messages=messages,
            model=model,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            n=n,
            max_tokens=max_tokens,
            stop=stop,
            stream=stream,
            frequency_penalty=frequency_penalty,
        )

        if prompt_builder is None:
            logger.warning(
                f"""bedrock.prompt_builder is not set.
Using the default prompt builder. Prompt sent to model will be:
----------------------------------------
{buildBasePrompt(request.messages)}.
----------------------------------------
If you want to use a custom prompt builder, set bedrock.prompt_builder to a function that takes a list of messages and returns a string.
You can also use existing prompt builders by importing them from easyllm.prompt_utils"""
            )
            prompt = buildBasePrompt(request.messages)
        else:
            prompt = build_prompt(request.messages, prompt_builder)

        # create stop sequences
        if isinstance(request.stop, list):
            stop = stop_sequences + request.stop
        elif isinstance(request.stop, str):
            stop = stop_sequences + [request.stop]
        else:
            stop = stop_sequences
        logger.debug(f"Stop sequences:\n{stop}")

        # check if we can stream
        if request.stream is True and request.n > 1:
            raise ValueError("Cannot stream more than one completion")

        # construct body
        body = {
            "prompt": prompt,
            "max_tokens_to_sample": request.max_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "stop_sequences": stop,
            "anthropic_version": model_version_mapping[model],
        }
        logger.debug(f"Generation body:\n{body}")

        if request.stream:
            return stream_chat_request(client, body, model)
        else:
            choices = []
            generated_tokens = 0
            for _i in range(request.n):
                response = client.invoke_model(
                    body=json.dumps(body), modelId=model, accept="application/json", contentType="application/json"
                )
                # parse response
                res = json.loads(response.get("body").read())

                # convert to schema
                parsed = ChatCompletionResponseChoice(
                    index=_i,
                    message=ChatMessage(role="assistant", content=res["completion"].strip()),
                    finish_reason=res["stop_reason"],
                )
                generated_tokens += int(len(res["completion"].strip()) / 4)
                choices.append(parsed)
                logger.debug(f"Response at index {_i}:\n{parsed}")
            # calculate usage details
            # TODO: fix when details is fixed
            prompt_tokens = int(len(prompt) / 4)
            total_tokens = prompt_tokens + generated_tokens

            return dump_object(
                ChatCompletionResponse(
                    model=request.model,
                    choices=choices,
                    usage=Usage(
                        prompt_tokens=prompt_tokens, completion_tokens=generated_tokens, total_tokens=total_tokens
                    ),
                )
            )

    @classmethod
    async def acreate(cls, *args, **kwargs):
        """
        Creates a new chat completion for the provided messages and parameters.
        """
        raise NotImplementedError("ChatCompletion.acreate is not implemented")
3 changes: 3 additions & 0 deletions easyllm/prompt_utils/__init__.py
@@ -1,3 +1,5 @@
from easyllm.prompt_utils.anthropic import anthropic_stop_sequences, build_anthropic_prompt

from .chatml_hf import (
    build_chatml_falcon_prompt,
    build_chatml_starchat_prompt,
@@ -20,4 +22,5 @@
    "vicuna": build_vicuna_prompt,
    "wizardlm": build_wizardlm_prompt,
    "falcon": build_falcon_prompt,
    "anthropic": build_anthropic_prompt,
}