Merge pull request #38 from philschmid/add-bedrock-client
Add bedrock client
philschmid authored Oct 26, 2023
2 parents cbd908b + 30d69f6 commit d461250
Showing 13 changed files with 1,020 additions and 9 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -33,6 +33,8 @@ EasyLLM implements clients that are **compatible with OpenAI's Completion API**.
* `sagemaker.ChatCompletion` - Chat with LLMs
* `sagemaker.Completion` - Text completion with LLMs
* `sagemaker.Embedding` - Create embeddings with LLMs
* `bedrock` - Amazon Bedrock LLMs


Check out the [Examples](./examples) to get started.

78 changes: 78 additions & 0 deletions docs/clients/bedrock.md
@@ -0,0 +1,78 @@
# Amazon Bedrock

EasyLLM provides a client for interfacing with Amazon Bedrock models.

- `bedrock.ChatCompletion` - a client for interfacing with Bedrock models that are compatible with the OpenAI ChatCompletion API.
- `bedrock.Completion` - a client for interfacing with Bedrock models that are compatible with the OpenAI Completion API.
- `bedrock.Embedding` - a client for interfacing with Bedrock models that are compatible with the OpenAI Embedding API.

## `bedrock.ChatCompletion`

The `bedrock.ChatCompletion` client is used to interface with Amazon Bedrock models that are compatible with the OpenAI ChatCompletion API. Check out the [Examples](../examples/bedrock-chat-completion-api) to see it in action.


```python
import os
# set env for prompt builder
os.environ["BEDROCK_PROMPT"] = "anthropic" # vicuna, wizardlm, stablebeluga, open_assistant
os.environ["AWS_REGION"] = "us-east-1" # change to your region
# os.environ["AWS_ACCESS_KEY_ID"] = "XXX" # needed if not using boto3 session
# os.environ["AWS_SECRET_ACCESS_KEY"] = "XXX" # needed if not using boto3 session

from easyllm.clients import bedrock

response = bedrock.ChatCompletion.create(
    model="anthropic.claude-v2",
    messages=[
        {"role": "user", "content": "What is 2 + 2?"},
    ],
    temperature=0.9,
    top_p=0.6,
    max_tokens=1024,
    debug=False,
)
```
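
The call returns an OpenAI-style ChatCompletion response. Assuming `dump_object` in the client returns a plain dict (as the implementation in this PR suggests), the generated text and the token usage estimate can be read like this:

```python
print(response["choices"][0]["message"]["content"])
print(response["usage"])
```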


Supported parameters are:

* `model` - The model to use for the completion. Must be one of the supported Bedrock models, e.g. `anthropic.claude-v2`.
* `messages` - `List[ChatMessage]` to use for the completion.
* `temperature` - The temperature to use for the completion. Defaults to 0.9.
* `top_p` - The top_p to use for the completion. Defaults to 0.6.
* `top_k` - The top_k to use for the completion. Defaults to 10.
* `n` - The number of completions to generate. Defaults to 1.
* `max_tokens` - The maximum number of tokens to generate. Defaults to 1024.
* `stop` - The stop sequence(s) to use for the completion. Defaults to None.
* `stream` - Whether to stream the completion. Defaults to False. See the streaming sketch after this list.
* `frequency_penalty` - The frequency penalty to use for the completion. Defaults to 1.0.
* `debug` - Whether to enable debug logging. Defaults to False.
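
With `stream=True` the client returns a generator of OpenAI-style chunk objects (see `stream_chat_request` in `easyllm/clients/bedrock.py` below). A minimal consumption sketch, assuming the chunks are plain dicts whose `delta` may carry a `role` or a piece of `content`:

```python
import os
os.environ["BEDROCK_PROMPT"] = "anthropic"
os.environ["AWS_REGION"] = "us-east-1"

from easyllm.clients import bedrock

for chunk in bedrock.ChatCompletion.create(
    model="anthropic.claude-v2",
    messages=[{"role": "user", "content": "Write a haiku about the ocean."}],
    stream=True,
):
    # skip chunks that only announce the role or close the stream
    content = chunk["choices"][0]["delta"].get("content")
    if content:
        print(content, end="", flush=True)
```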


### Build Prompt

By default, the `bedrock` client reads the `BEDROCK_PROMPT` environment variable and maps its value to a prompt builder via the `PROMPT_MAPPING` dictionary. If the variable is not set, the default prompt builder is used.
You can also set the prompt builder manually.

Check out the [Prompt Utils](../prompt_utils) for more details.


Manually setting the prompt builder:

```python
from easyllm.clients import bedrock

bedrock.prompt_builder = "anthropic"

res = bedrock.ChatCompletion.create(...)
```
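
As the client's warning message notes, `prompt_builder` can also be set to a custom function that takes a list of messages and returns a string. A minimal sketch (the message objects may arrive as dicts or as `ChatMessage` instances, so this handles both; the builder name is illustrative):

```python
from easyllm.clients import bedrock

def my_prompt_builder(messages):
    # concatenate the conversation into a single prompt string
    parts = []
    for m in messages:
        role = m["role"] if isinstance(m, dict) else m.role
        content = m["content"] if isinstance(m, dict) else m.content
        parts.append(f"\n\n{role.capitalize()}: {content}")
    return "".join(parts) + "\n\nAssistant:"

bedrock.prompt_builder = my_prompt_builder
res = bedrock.ChatCompletion.create(...)
```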

Using the environment variable:

```python
# can happen elsewhere
import os
os.environ["BEDROCK_PROMPT"] = "anthropic"

from easyllm.clients import bedrock
```
35 changes: 29 additions & 6 deletions docs/prompt_utils.md
@@ -4,12 +4,17 @@ The `prompt_utils` module contains functions to assist with converting Message'

Supported prompt formats:

- [Prompt utilities](#prompt-utilities)
- [Set prompt builder for client](#set-prompt-builder-for-client)
- [Llama 2 Chat builder](#llama-2-chat-builder)
- [Vicuna Chat builder](#vicuna-chat-builder)
- [Hugging Face ChatML builder](#hugging-face-chatml-builder)
- [StarChat](#starchat)
- [Falcon](#falcon)
- [WizardLM Chat builder](#wizardlm-chat-builder)
- [StableBeluga2 Chat builder](#stablebeluga2-chat-builder)
- [Open Assistant Chat builder](#open-assistant-chat-builder)
- [Anthropic Claude Chat builder](#anthropic-claude-chat-builder)

The prompt utils also export a mapping dictionary, `PROMPT_MAPPING`, that maps a model name to a prompt builder function. It can be used to select the correct prompt builder via an environment variable.
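
A short sketch of resolving a builder name (for example one coming from an environment variable) through `PROMPT_MAPPING`; this assumes `PROMPT_MAPPING` is importable from `easyllm.prompt_utils` as described above:

```python
import os

from easyllm.prompt_utils import PROMPT_MAPPING

# resolve the builder name to the corresponding function, e.g. "anthropic" added in this PR
builder_name = os.environ.get("BEDROCK_PROMPT", "anthropic")
build_prompt_fn = PROMPT_MAPPING[builder_name]

prompt = build_prompt_fn([{"role": "user", "content": "What is 2 + 2?"}])
```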

@@ -152,3 +157,21 @@ messages=[
prompt = build_open_assistant_prompt(messages)
```

## Anthropic Claude Chat builder

Creates the Anthropic Claude prompt template, using the `\n\nHuman:` and `\n\nAssistant:` turn markers. If a `Message` with an unsupported `role` is passed, an error is thrown. [Reference](https://docs.anthropic.com/claude/docs/introduction-to-prompt-design)

Example Models:

* [Bedrock](https://aws.amazon.com/bedrock/claude/)

```python
from easyllm.prompt_utils import build_anthropic_prompt

messages=[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Explain asynchronous programming in the style of the pirate Blackbeard."},
]
prompt = build_anthropic_prompt(messages)
```
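
For orientation, a rough sketch of what the builder produces for the messages above; this is illustrative only, since the exact handling of the system message and whitespace is defined by `build_anthropic_prompt` itself:

```python
print(prompt)
# Illustrative output (not verbatim):
#
# You are a helpful assistant.
#
# Human: Explain asynchronous programming in the style of the pirate Blackbeard.
#
# Assistant:
```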

2 changes: 1 addition & 1 deletion easyllm/__init__.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: 2023-present philschmid <[email protected]>
#
# SPDX-License-Identifier: MIT
__version__ = "0.6.0.dev0"
__version__ = "0.6.0"
219 changes: 219 additions & 0 deletions easyllm/clients/bedrock.py
@@ -0,0 +1,219 @@
import json
import logging
import os
from typing import Any, Dict, List, Optional

from nanoid import generate

from easyllm.prompt_utils.base import build_prompt, buildBasePrompt
from easyllm.schema.base import ChatMessage, Usage, dump_object
from easyllm.schema.openai import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
DeltaMessage,
)
from easyllm.utils import setup_logger
from easyllm.utils.aws import get_bedrock_client

logger = setup_logger()

# default parameters
api_type = "bedrock"
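# credentials are optional; when unset, get_bedrock_client is expected to fall back to the ambient boto3 session/credential chain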
api_aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", None)
api_aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None)
api_aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None)

client = get_bedrock_client(
    aws_access_key_id=api_aws_access_key,
    aws_secret_access_key=api_aws_secret_key,
    aws_session_token=api_aws_session_token,
)


SUPPORTED_MODELS = [
"anthropic.claude-v2",
]
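# maps each supported model to the anthropic_version value sent in the request body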
model_version_mapping = {"anthropic.claude-v2": "bedrock-2023-05-31"}

api_version = os.environ.get("BEDROCK_API_VERSION", None) or "bedrock-2023-05-31"
prompt_builder = os.environ.get("BEDROCK_PROMPT", None)
stop_sequences = []


def stream_chat_request(client, body, model):
    """Utility function for streaming chat requests."""
    id = f"hf-{generate(size=10)}"
    response = client.invoke_model_with_response_stream(
        body=json.dumps(body), modelId=model, accept="application/json", contentType="application/json"
    )
    stream = response.get("body")

    # announce the assistant role first, mirroring the OpenAI streaming format
    yield dump_object(
        ChatCompletionStreamResponse(
            id=id,
            model=model,
            choices=[ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"))],
        )
    )
    # yield each generated token
    reason = None
    for _idx, event in enumerate(stream):
        chunk = event.get("chunk")
        if chunk:
            chunk_obj = json.loads(chunk.get("bytes").decode())
            text = chunk_obj["completion"]
            yield dump_object(
                ChatCompletionStreamResponse(
                    id=id,
                    model=model,
                    choices=[ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(content=text))],
                )
            )
    # closing chunk with the finish reason and an empty delta
    yield dump_object(
        ChatCompletionStreamResponse(
            id=id,
            model=model,
            choices=[ChatCompletionResponseStreamChoice(index=0, finish_reason=reason, delta={})],
        )
    )


class ChatCompletion:
    @staticmethod
    def create(
        messages: List[ChatMessage],
        model: Optional[str] = None,
        temperature: float = 0.9,
        top_p: float = 0.6,
        top_k: Optional[int] = 10,
        n: int = 1,
        max_tokens: int = 1024,
        stop: Optional[List[str]] = None,
        stream: bool = False,
        frequency_penalty: Optional[float] = 1.0,
        debug: bool = False,
    ) -> Dict[str, Any]:
        """
        Creates a new chat completion for the provided messages and parameters.
        Args:
            messages (`List[ChatMessage]`): The messages to use for the completion.
            model (`str`, *optional*, defaults to None): The model to use for the completion. Must be one of the
                supported Bedrock models, e.g. `anthropic.claude-v2`.
            temperature (`float`, defaults to 0.9): The temperature to use for the completion.
            top_p (`float`, defaults to 0.6): The top_p to use for the completion.
            top_k (`int`, *optional*, defaults to 10): The top_k to use for the completion.
            n (`int`, defaults to 1): The number of completions to generate.
            max_tokens (`int`, defaults to 1024): The maximum number of tokens to generate.
            stop (`List[str]`, *optional*, defaults to None): The stop sequence(s) to use for the completion.
            stream (`bool`, defaults to False): Whether to stream the completion.
            frequency_penalty (`float`, *optional*, defaults to 1.0): The frequency penalty to use for the completion.
            debug (`bool`, defaults to False): Whether to enable debug logging.
        Tip: Prompt builder
            Make sure to always use a prompt builder for your model.
        """
        if debug:
            logger.setLevel(logging.DEBUG)

        # validate that the requested model is supported
        if model not in SUPPORTED_MODELS:
            raise ValueError(f"Model {model} is not supported. Supported models are: {SUPPORTED_MODELS}")

        request = ChatCompletionRequest(
            messages=messages,
            model=model,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            n=n,
            max_tokens=max_tokens,
            stop=stop,
            stream=stream,
            frequency_penalty=frequency_penalty,
        )

        if prompt_builder is None:
            logger.warning(
                f"""bedrock.prompt_builder is not set.
Using the default prompt builder. The prompt sent to the model will be:
----------------------------------------
{buildBasePrompt(request.messages)}.
----------------------------------------
If you want to use a custom prompt builder, set bedrock.prompt_builder to a function that takes a list of messages and returns a string.
You can also use existing prompt builders by importing them from easyllm.prompt_utils"""
            )
            prompt = buildBasePrompt(request.messages)
        else:
            prompt = build_prompt(request.messages, prompt_builder)

        # create stop sequences
        if isinstance(request.stop, list):
            stop = stop_sequences + request.stop
        elif isinstance(request.stop, str):
            stop = stop_sequences + [request.stop]
        else:
            stop = stop_sequences
        logger.debug(f"Stop sequences:\n{stop}")

        # check if we can stream
        if request.stream is True and request.n > 1:
            raise ValueError("Cannot stream more than one completion")

        # construct the Bedrock invoke_model body for Anthropic Claude
        body = {
            "prompt": prompt,
            "max_tokens_to_sample": request.max_tokens,
            "temperature": request.temperature,
            "top_k": request.top_k,
            "top_p": request.top_p,
            "stop_sequences": stop,
            "anthropic_version": model_version_mapping[model],
        }
        logger.debug(f"Generation body:\n{body}")

        if request.stream:
            return stream_chat_request(client, body, model)
        else:
            choices = []
            generated_tokens = 0
            for _i in range(request.n):
                response = client.invoke_model(
                    body=json.dumps(body), modelId=model, accept="application/json", contentType="application/json"
                )
                # parse response
                res = json.loads(response.get("body").read())

                # convert to schema
                parsed = ChatCompletionResponseChoice(
                    index=_i,
                    message=ChatMessage(role="assistant", content=res["completion"].strip()),
                    finish_reason=res["stop_reason"],
                )
                # rough estimate: ~4 characters per token
                generated_tokens += int(len(res["completion"].strip()) / 4)
                choices.append(parsed)
                logger.debug(f"Response at index {_i}:\n{parsed}")
            # calculate usage details
            # TODO: fix when details is fixed
            prompt_tokens = int(len(prompt) / 4)
            total_tokens = prompt_tokens + generated_tokens

            return dump_object(
                ChatCompletionResponse(
                    model=request.model,
                    choices=choices,
                    usage=Usage(
                        prompt_tokens=prompt_tokens, completion_tokens=generated_tokens, total_tokens=total_tokens
                    ),
                )
            )

    @classmethod
    async def acreate(cls, *args, **kwargs):
        """
        Creates a new chat completion for the provided messages and parameters.
        """
        raise NotImplementedError("ChatCompletion.acreate is not implemented")
3 changes: 3 additions & 0 deletions easyllm/prompt_utils/__init__.py
@@ -1,3 +1,5 @@
from easyllm.prompt_utils.anthropic import anthropic_stop_sequences, build_anthropic_prompt

from .chatml_hf import (
    build_chatml_falcon_prompt,
    build_chatml_starchat_prompt,
@@ -20,4 +22,5 @@
"vicuna": build_vicuna_prompt,
"wizardlm": build_wizardlm_prompt,
"falcon": build_falcon_prompt,
"anthropic": build_anthropic_prompt,
}