diff --git a/README.md b/README.md
index 2c4c095..b6291e6 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,8 @@ EasyLLM implements clients that are **compatible with OpenAI's Completion API**.
 * `sagemaker.ChatCompletion` - Chat with LLMs
 * `sagemaker.Completion` - Text completion with LLMs
 * `sagemaker.Embedding` - Create embeddings with LLMs
+* `bedrock` - Amazon Bedrock LLMs
+
 Check out the [Examples](./examples) to get started.
diff --git a/docs/clients/bedrock.md b/docs/clients/bedrock.md
new file mode 100644
index 0000000..5de328e
--- /dev/null
+++ b/docs/clients/bedrock.md
@@ -0,0 +1,78 @@
+# Amazon Bedrock
+
+EasyLLM provides a client for interfacing with Amazon Bedrock models.
+
+- `bedrock.ChatCompletion` - a client for interfacing with Bedrock models that are compatible with the OpenAI ChatCompletion API.
+- `bedrock.Completion` - a client for interfacing with Bedrock models that are compatible with the OpenAI Completion API.
+- `bedrock.Embedding` - a client for interfacing with Bedrock models that are compatible with the OpenAI Embedding API.
+
+## `bedrock.ChatCompletion`
+
+The `bedrock.ChatCompletion` client is used to interface with models hosted on Amazon Bedrock that are compatible with the OpenAI ChatCompletion API. Check out the [Examples](../examples/bedrock-chat-completion-api).
+
+
+```python
+import os
+# set env for prompt builder
+os.environ["BEDROCK_PROMPT"] = "anthropic" # vicuna, wizardlm, stablebeluga, open_assistant
+os.environ["AWS_REGION"] = "us-east-1" # change to your region
+# os.environ["AWS_ACCESS_KEY_ID"] = "XXX" # needed if not using boto3 session
+# os.environ["AWS_SECRET_ACCESS_KEY"] = "XXX" # needed if not using boto3 session
+
+from easyllm.clients import bedrock
+
+response = bedrock.ChatCompletion.create(
+    model="anthropic.claude-v2",
+    messages=[
+        {"role": "user", "content": "What is 2 + 2?"},
+    ],
+    temperature=0.9,
+    top_p=0.6,
+    max_tokens=1024,
+    debug=False,
+)
+```
+
+
+Supported parameters are:
+
+* `model` - The Bedrock model to use for the completion, e.g. `anthropic.claude-v2`. Must be one of the supported models.
+* `messages` - The `List[ChatMessage]` to use for the completion.
+* `temperature` - The temperature to use for the completion. Defaults to 0.9.
+* `top_p` - The top_p to use for the completion. Defaults to 0.6.
+* `top_k` - The top_k to use for the completion. Defaults to 10.
+* `n` - The number of completions to generate. Defaults to 1.
+* `max_tokens` - The maximum number of tokens to generate. Defaults to 1024.
+* `stop` - The stop sequence(s) to use for the completion. Defaults to None.
+* `stream` - Whether to stream the completion. Defaults to False.
+* `frequency_penalty` - The frequency penalty to use for the completion. Defaults to 1.0.
+* `debug` - Whether to enable debug logging. Defaults to False.
+
+
+### Build Prompt
+
+By default, the `bedrock` client tries to read the `BEDROCK_PROMPT` environment variable and map its value to the `PROMPT_MAPPING` dictionary. If it is not set, the default prompt builder is used.
+You can also set the prompt builder manually.
+
+Check out the [Prompt Utils](../prompt_utils) for more details.
+
+
+Manually setting the prompt builder:
+
+```python
+from easyllm.clients import bedrock
+
+bedrock.prompt_builder = "anthropic"
+
+res = bedrock.ChatCompletion.create(...)
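+
+# bedrock.prompt_builder also accepts a callable that takes a list of messages and returns
+# a prompt string, e.g. an existing builder such as build_anthropic_prompt from
+# easyllm.prompt_utils, instead of a mapping name like "anthropic".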
+```
+
+Using environment variable:
+
+```python
+# can happen elsewhere
+import os
+os.environ["BEDROCK_PROMPT"] = "anthropic"
+
+from easyllm.clients import bedrock
+```
\ No newline at end of file
diff --git a/docs/prompt_utils.md b/docs/prompt_utils.md
index f587a90..c6b0898 100644
--- a/docs/prompt_utils.md
+++ b/docs/prompt_utils.md
@@ -4,12 +4,17 @@ The `prompt_utils` module contains functions to assist with converting Message'
 Supported prompt formats:
 
-* [Llama 2](#llama-2-chat-builder)
-* [Vicuna](#vicuna-chat-builder)
-* [Hugging Face ChatML](#hugging-face-chatml-builder)
-* [WizardLM](#wizardlm-chat-builder)
-* [StableBeluga2](#stablebeluga2-chat-builder)
-* [Open Assistant](#open-assistant-chat-builder)
+- [Prompt utilities](#prompt-utilities)
+  - [Set prompt builder for client](#set-prompt-builder-for-client)
+  - [Llama 2 Chat builder](#llama-2-chat-builder)
+  - [Vicuna Chat builder](#vicuna-chat-builder)
+  - [Hugging Face ChatML builder](#hugging-face-chatml-builder)
+  - [StarChat](#starchat)
+  - [Falcon](#falcon)
+  - [WizardLM Chat builder](#wizardlm-chat-builder)
+  - [StableBeluga2 Chat builder](#stablebeluga2-chat-builder)
+  - [Open Assistant Chat builder](#open-assistant-chat-builder)
+  - [Anthropic Claude Chat builder](#anthropic-claude-chat-builder)
 
 Prompt utils are also exporting a mapping dictionary `PROMPT_MAPPING` that maps a model name to a prompt builder function. This can be used to select the correct prompt builder function via an environment variable.
@@ -152,3 +157,21 @@ messages=[
 prompt = build_open_assistant_prompt(messages)
 ```
 
+## Anthropic Claude Chat builder
+
+Creates the Anthropic Claude prompt template. Uses `\n\nHuman:` and `\n\nAssistant:`. If a `Message` with an unsupported `role` is passed, an error will be thrown.
[Reference](https://docs.anthropic.com/claude/docs/introduction-to-prompt-design) + +Example Models: + +* [Bedrock](https://aws.amazon.com/bedrock/claude/) + +```python +from easyllm.prompt_utils import build_anthropic_prompt + +messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Explain asynchronous programming in the style of the pirate Blackbeard."}, +] +prompt = build_anthropic_prompt(messages) +``` + diff --git a/easyllm/__init__.py b/easyllm/__init__.py index 253e7ff..064927b 100644 --- a/easyllm/__init__.py +++ b/easyllm/__init__.py @@ -1,4 +1,4 @@ # SPDX-FileCopyrightText: 2023-present philschmid # # SPDX-License-Identifier: MIT -__version__ = "0.6.0.dev0" +__version__ = "0.6.0" diff --git a/easyllm/clients/bedrock.py b/easyllm/clients/bedrock.py new file mode 100644 index 0000000..21236d4 --- /dev/null +++ b/easyllm/clients/bedrock.py @@ -0,0 +1,219 @@ +import json +import logging +import os +from typing import Any, Dict, List, Optional + +from nanoid import generate + +from easyllm.prompt_utils.base import build_prompt, buildBasePrompt +from easyllm.schema.base import ChatMessage, Usage, dump_object +from easyllm.schema.openai import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionResponseChoice, + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + DeltaMessage, +) +from easyllm.utils import setup_logger +from easyllm.utils.aws import get_bedrock_client + +logger = setup_logger() + +# default parameters +api_type = "bedrock" +api_aws_access_key = os.environ.get("AWS_ACCESS_KEY_ID", None) +api_aws_secret_key = os.environ.get("AWS_SECRET_ACCESS_KEY", None) +api_aws_session_token = os.environ.get("AWS_SESSION_TOKEN", None) + +client = get_bedrock_client( + aws_access_key_id=api_aws_access_key, + aws_secret_access_key=api_aws_secret_key, + aws_session_token=api_aws_session_token, +) + + +SUPPORTED_MODELS = [ + "anthropic.claude-v2", +] +model_version_mapping = {"anthropic.claude-v2": "bedrock-2023-05-31"} + +api_version = os.environ.get("BEDROCK_API_VERSION", None) or "bedrock-2023-05-31" +prompt_builder = os.environ.get("BEDROCK_PROMPT", None) +stop_sequences = [] + + +def stream_chat_request(client, body, model): + """Utility function for streaming chat requests.""" + id = f"hf-{generate(size=10)}" + response = client.invoke_model_with_response_stream( + body=json.dumps(body), modelId=model, accept="application/json", contentType="application/json" + ) + stream = response.get("body") + + yield dump_object( + ChatCompletionStreamResponse( + id=id, + model=model, + choices=[ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"))], + ) + ) + # yield each generated token + reason = None + for _idx, event in enumerate(stream): + chunk = event.get("chunk") + if chunk: + chunk_obj = json.loads(chunk.get("bytes").decode()) + text = chunk_obj["completion"] + yield dump_object( + ChatCompletionStreamResponse( + id=id, + model=model, + choices=[ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(content=text))], + ) + ) + yield dump_object( + ChatCompletionStreamResponse( + id=id, + model=model, + choices=[ChatCompletionResponseStreamChoice(index=0, finish_reason=reason, delta={})], + ) + ) + + +class ChatCompletion: + @staticmethod + def create( + messages: List[ChatMessage], + model: Optional[str] = None, + temperature: float = 0.9, + top_p: float = 0.6, + top_k: Optional[int] = 10, + n: int = 1, + max_tokens: int = 1024, + stop: Optional[List[str]] = None, + 
        stream: bool = False,
+        frequency_penalty: Optional[float] = 1.0,
+        debug: bool = False,
+    ) -> Dict[str, Any]:
+        """
+        Creates a new chat completion for the provided messages and parameters.
+
+        Args:
+            messages (`List[ChatMessage]`): The messages to use for the completion.
+            model (`str`, *optional*, defaults to None): The Bedrock model to use for the completion,
+                e.g. `anthropic.claude-v2`. Must be one of the supported models.
+            temperature (`float`, defaults to 0.9): The temperature to use for the completion.
+            top_p (`float`, defaults to 0.6): The top_p to use for the completion.
+            top_k (`int`, *optional*, defaults to 10): The top_k to use for the completion.
+            n (`int`, defaults to 1): The number of completions to generate.
+            max_tokens (`int`, defaults to 1024): The maximum number of tokens to generate.
+            stop (`List[str]`, *optional*, defaults to None): The stop sequence(s) to use for the completion.
+            stream (`bool`, defaults to False): Whether to stream the completion.
+            frequency_penalty (`float`, *optional*, defaults to 1.0): The frequency penalty to use for the completion.
+            debug (`bool`, defaults to False): Whether to enable debug logging.
+
+        Tip: Prompt builder
+            Make sure to always use a prompt builder for your model.
+        """
+        if debug:
+            logger.setLevel(logging.DEBUG)
+
+        # validate that the requested model is supported
+        if model not in SUPPORTED_MODELS:
+            raise ValueError(f"Model {model} is not supported. Supported models are: {SUPPORTED_MODELS}")
+
+        request = ChatCompletionRequest(
+            messages=messages,
+            model=model,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            n=n,
+            max_tokens=max_tokens,
+            stop=stop,
+            stream=stream,
+            frequency_penalty=frequency_penalty,
+        )
+
+        if prompt_builder is None:
+            logger.warning(
+                f"""bedrock.prompt_builder is not set.
+Using the default prompt builder. The prompt sent to the model will be:
+----------------------------------------
+{buildBasePrompt(request.messages)}
+----------------------------------------
+If you want to use a custom prompt builder, set bedrock.prompt_builder to a function that takes a list of messages and returns a string.
+You can also use existing prompt builders by importing them from easyllm.prompt_utils""" + ) + prompt = buildBasePrompt(request.messages) + else: + prompt = build_prompt(request.messages, prompt_builder) + + # create stop sequences + if isinstance(request.stop, list): + stop = stop_sequences + request.stop + elif isinstance(request.stop, str): + stop = stop_sequences + [request.stop] + else: + stop = stop_sequences + logger.debug(f"Stop sequences:\n{stop}") + + # check if we can stream + if request.stream is True and request.n > 1: + raise ValueError("Cannot stream more than one completion") + + # construct body + body = { + "prompt": prompt, + "max_tokens_to_sample": request.max_tokens, + "temperature": request.temperature, + "top_k": request.top_k, + "top_p": request.top_p, + "stop_sequences": stop, + "anthropic_version": model_version_mapping[model], + } + logger.debug(f"Generation body:\n{body}") + + if request.stream: + return stream_chat_request(client, body, model) + else: + choices = [] + generated_tokens = 0 + for _i in range(request.n): + response = client.invoke_model( + body=json.dumps(body), modelId=model, accept="application/json", contentType="application/json" + ) + # parse response + res = json.loads(response.get("body").read()) + + # convert to schema + parsed = ChatCompletionResponseChoice( + index=_i, + message=ChatMessage(role="assistant", content=res["completion"].strip()), + finish_reason=res["stop_reason"], + ) + generated_tokens += len(res["completion"].strip()) / 4 + choices.append(parsed) + logger.debug(f"Response at index {_i}:\n{parsed}") + # calculate usage details + # TODO: fix when details is fixed + prompt_tokens = int(len(prompt) / 4) + total_tokens = prompt_tokens + generated_tokens + + return dump_object( + ChatCompletionResponse( + model=request.model, + choices=choices, + usage=Usage( + prompt_tokens=prompt_tokens, completion_tokens=generated_tokens, total_tokens=total_tokens + ), + ) + ) + + @classmethod + async def acreate(cls, *args, **kwargs): + """ + Creates a new chat completion for the provided messages and parameters. + """ + raise NotImplementedError("ChatCompletion.acreate is not implemented") diff --git a/easyllm/prompt_utils/__init__.py b/easyllm/prompt_utils/__init__.py index 6c6f767..4dfe8e9 100644 --- a/easyllm/prompt_utils/__init__.py +++ b/easyllm/prompt_utils/__init__.py @@ -1,3 +1,5 @@ +from easyllm.prompt_utils.anthropic import anthropic_stop_sequences, build_anthropic_prompt + from .chatml_hf import ( build_chatml_falcon_prompt, build_chatml_starchat_prompt, @@ -20,4 +22,5 @@ "vicuna": build_vicuna_prompt, "wizardlm": build_wizardlm_prompt, "falcon": build_falcon_prompt, + "anthropic": build_anthropic_prompt, } diff --git a/easyllm/prompt_utils/anthropic.py b/easyllm/prompt_utils/anthropic.py new file mode 100644 index 0000000..374f9dc --- /dev/null +++ b/easyllm/prompt_utils/anthropic.py @@ -0,0 +1,41 @@ +from typing import Dict, List, Union + +from easyllm.schema.base import ChatMessage + +# Define stop sequences for anthropic +anthropic_stop_sequences = ["\n\nUser:", "User:"] + + +def build_anthropic_prompt(messages: Union[List[Dict[str, str]], str, List[ChatMessage]]) -> str: + """ + Builds a anthropic prompt for a chat conversation. refrence https://huggingface.co/blog/anthropic-180b#prompt-format + + Args: + messages (Union[List[ChatMessage], str]): The messages to use for the completion. + Returns: + str: The anthropic prompt string. 
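+
+    Example (expected output given the tokens defined below):
+        >>> build_anthropic_prompt([{"role": "user", "content": "Hi"}])
+        '\n\nHuman: Hi\n\nAssistant: '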
+ """ + ANTHROPIC_USER_TOKEN = "\n\nHuman:" + ANTHROPIC_ASSISTANT_TOKEN = "\n\nAssistant:" + + conversation = [] + + if isinstance(messages, str): + messages = [ChatMessage(content="", role="system"), ChatMessage(content=messages, role="user")] + else: + if isinstance(messages[0], dict): + messages = [ChatMessage(**message) for message in messages] + + for index, message in enumerate(messages): + if message.role == "user": + conversation.append(f"{ANTHROPIC_USER_TOKEN} {message.content.strip()}") + elif message.role == "assistant": + conversation.append(f"{ANTHROPIC_ASSISTANT_TOKEN} {message.content.strip()}") + elif message.role == "function": + raise ValueError("anthropic does not support function calls.") + elif message.role == "system" and index == 0: + conversation.append(message.content) + else: + raise ValueError(f"Invalid message role: {message.role}") + + return "".join(conversation) + ANTHROPIC_ASSISTANT_TOKEN + " " diff --git a/easyllm/schema/openai.py b/easyllm/schema/openai.py index 2187d3b..6961bcb 100644 --- a/easyllm/schema/openai.py +++ b/easyllm/schema/openai.py @@ -26,7 +26,7 @@ class ChatCompletionRequest(BaseModel): class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage - finish_reason: Optional[Literal["stop_sequence", "length", "eos_token"]] = None + finish_reason: Optional[Literal["stop_sequence", "length", "eos_token", "max_tokens"]] = None class ChatCompletionResponse(BaseModel): diff --git a/easyllm/utils/__init__.py b/easyllm/utils/__init__.py index 6daa519..2482be5 100644 --- a/easyllm/utils/__init__.py +++ b/easyllm/utils/__init__.py @@ -1,2 +1,2 @@ -from easyllm.utils.aws import AWSSigV4 +from easyllm.utils.aws import AWSSigV4, get_bedrock_client from easyllm.utils.logging import setup_logger diff --git a/easyllm/utils/aws.py b/easyllm/utils/aws.py index c974bfc..b491130 100644 --- a/easyllm/utils/aws.py +++ b/easyllm/utils/aws.py @@ -5,7 +5,9 @@ import os import urllib.parse from datetime import datetime +from typing import Optional +from botocore.config import Config from requests import __version__ as requests_version from requests.auth import AuthBase from requests.compat import urlparse @@ -203,3 +205,69 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: signature, ) return r + + +def get_bedrock_client( + assumed_role: Optional[str] = None, + region: Optional[str] = None, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + runtime: Optional[bool] = True, +): + """Create a boto3 client for Amazon Bedrock, with optional configuration overrides + + Parameters + ---------- + assumed_role : + Optional ARN of an AWS IAM role to assume for calling the Bedrock service. If not + specified, the current active credentials will be used. + region : + Optional name of the AWS Region in which the service should be called (e.g. "us-east-1"). + If not specified, AWS_REGION or AWS_DEFAULT_REGION environment variable will be used. + runtime : + Optional choice of getting different client to perform operations with the Amazon Bedrock service. 
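+    aws_access_key_id, aws_secret_access_key, aws_session_token :
+        Optional explicit AWS credentials to use for the client. If not provided, the default
+        boto3 credential chain (environment variables, shared config/profile, or instance role) is used.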
+ """ + if region is None: + target_region = os.environ.get("AWS_REGION", os.environ.get("AWS_DEFAULT_REGION")) + else: + target_region = region + + session_kwargs = {"region_name": target_region} + client_kwargs = {**session_kwargs} + + profile_name = os.environ.get("AWS_PROFILE") + if profile_name: + session_kwargs["profile_name"] = profile_name + + retry_config = Config( + region_name=target_region, + retries={ + "max_attempts": 10, + "mode": "standard", + }, + ) + session = boto3.Session(**session_kwargs) + + if assumed_role: + logger.info(f" Using role: {assumed_role}", end="") + sts = session.client("sts") + response = sts.assume_role(RoleArn=str(assumed_role), RoleSessionName="llm-bedrock") + logger.info(" ... successful!") + client_kwargs["aws_access_key_id"] = response["Credentials"]["AccessKeyId"] + client_kwargs["aws_secret_access_key"] = response["Credentials"]["SecretAccessKey"] + client_kwargs["aws_session_token"] = response["Credentials"]["SessionToken"] + else: + client_kwargs["aws_access_key_id"] = aws_access_key_id + client_kwargs["aws_secret_access_key"] = aws_secret_access_key + client_kwargs["aws_session_token"] = aws_session_token + + if runtime: + service_name = "bedrock-runtime" + else: + service_name = "bedrock" + + bedrock_client = session.client(service_name=service_name, config=retry_config, **client_kwargs) + + logger.info("boto3 Bedrock client successfully created!") + return bedrock_client diff --git a/notebooks/bedrock-chat-completion-api.ipynb b/notebooks/bedrock-chat-completion-api.ipynb new file mode 100644 index 0000000..8a37276 --- /dev/null +++ b/notebooks/bedrock-chat-completion-api.ipynb @@ -0,0 +1,354 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to use Chat Completion clients with Amazon Bedrock\n", + "\n", + "EasyLLM can be used as an abstract layer to replace `gpt-3.5-turbo` and `gpt-4` with Amazon Bedrock models.\n", + "\n", + "You can change your own applications from the OpenAI API, by simply changing the client. \n", + "\n", + "Chat models take a series of messages as input, and return an AI-written message as output.\n", + "\n", + "This guide illustrates the chat format with a few example API calls.\n", + "\n", + "## 0. Setup\n", + "\n", + "Before you can use `easyllm` with Amazon Bedrock you need setup permission and access to the models. You can do this by following of the instructions below:\n", + "* https://docs.aws.amazon.com/IAM/latest/UserGuide/getting-set-up.html\n", + "* https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_access-denied.html\n", + "* https://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Import the easyllm library" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# if needed, install and/or upgrade to the latest version of the EasyLLM Python library\n", + "%pip install --upgrade easyllm[bedrock] " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# import the EasyLLM Python library for calling the EasyLLM API\n", + "import easyllm" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 
An example chat API call\n", + "\n", + "A chat API call has two required inputs:\n", + "- `model`: the name of the model you want to use (e.g., `huggingface-pytorch-tgi-inference-2023-08-08-14-15-52-703`) or leave it empty to just call the api\n", + "- `messages`: a list of message objects, where each object has two required fields:\n", + " - `role`: the role of the messenger (either `system`, `user`, or `assistant`)\n", + " - `content`: the content of the message (e.g., `Write me a beautiful poem`)\n", + "\n", + "Compared to OpenAI api is the `huggingface` module also exposing a `prompt_builder` and `stop_sequences` parameter you can use to customize the prompt and stop sequences. The EasyLLM package comes with prompt builder utilities.\n", + "\n", + "Let's look at an example chat API calls to see how the chat format works in practice." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'completion': ' 2 + 2 = 4', 'stop_reason': 'stop_sequence'}\n" + ] + }, + { + "data": { + "text/plain": [ + "{'id': 'hf-Mf7UqliZQP',\n", + " 'object': 'chat.completion',\n", + " 'created': 1698333425,\n", + " 'model': 'anthropic.claude-v2',\n", + " 'choices': [{'index': 0,\n", + " 'message': {'role': 'assistant', 'content': '2 + 2 = 4'},\n", + " 'finish_reason': 'stop_sequence'}],\n", + " 'usage': {'prompt_tokens': 9, 'completion_tokens': 9, 'total_tokens': 18}}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os \n", + "# set env for prompt builder\n", + "os.environ[\"BEDROCK_PROMPT\"] = \"anthropic\" # vicuna, wizardlm, stablebeluga, open_assistant\n", + "os.environ[\"AWS_REGION\"] = \"us-east-1\" # change to your region\n", + "# os.environ[\"AWS_ACCESS_KEY_ID\"] = \"XXX\" # needed if not using boto3 session\n", + "# os.environ[\"AWS_SECRET_ACCESS_KEY\"] = \"XXX\" # needed if not using boto3 session\n", + "\n", + "from easyllm.clients import bedrock\n", + "\n", + "response = bedrock.ChatCompletion.create(\n", + " model=\"anthropic.claude-v2\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n", + " ],\n", + " temperature=0.9,\n", + " top_p=0.6,\n", + " max_tokens=1024,\n", + " debug=False,\n", + ")\n", + "response\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, the response object has a few fields:\n", + "- `id`: the ID of the request\n", + "- `object`: the type of object returned (e.g., `chat.completion`)\n", + "- `created`: the timestamp of the request\n", + "- `model`: the full name of the model used to generate the response\n", + "- `usage`: the number of tokens used to generate the replies, counting prompt, completion, and total\n", + "- `choices`: a list of completion objects (only one, unless you set `n` greater than 1)\n", + " - `message`: the message object generated by the model, with `role` and `content`\n", + " - `finish_reason`: the reason the model stopped generating text (either `stop`, or `length` if `max_tokens` limit was reached)\n", + " - `index`: the index of the completion in the list of choices" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Extract just the reply with:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2 + 2 = 4\n" + ] + } + ], 
+ "source": [ + "print(response['choices'][0]['message']['content'])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Even non-conversation-based tasks can fit into the chat format, by placing the instruction in the first user message.\n", + "\n", + "For example, to ask the model to explain asynchronous programming in the style of the pirate Blackbeard, we can structure conversation as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'completion': ' Okay class, today we\\'re going to learn about asynchronous programming. Asynchronous means things happening at different times, not necessarily in order. It\\'s like when you\\'re cooking dinner - you might put the pasta on to boil, then start chopping vegetables while the pasta cooks. You don\\'t have to wait for the pasta to finish boiling before you can start on the vegetables. The two tasks are happening asynchronously. \\n\\nIn programming, asynchronous functions allow the code to execute other operations while waiting for a long-running task to complete. Let\\'s look at an example:\\n\\n```js\\nfunction cookPasta() {\\n console.log(\"Putting pasta on to boil...\");\\n // Simulate a long task\\n setTimeout(() => {\\n console.log(\"Pasta done!\");\\n }, 5000); \\n}\\n\\nfunction chopVegetables() {\\n console.log(\"Chopping vegetables...\");\\n}\\n\\ncookPasta();\\nchopVegetables();\\n```\\n\\nWhen we call `cookPasta()`, it starts the timer but doesn\\'t wait 5 seconds - it immediately moves on to calling `chopVegetables()`. So the two functions run asynchronously. \\n\\nThe key is that `cookPasta()` is non-blocking - it doesn\\'t stop the rest of the code from running while it completes. This allows us to maximize efficiency and not waste time waiting.\\n\\nSo in summary, asynchronous programming allows multiple operations to happen independently of each other, like cooking a meal. We avoid blocking code execution by using asynchronous functions. Any questions on this?', 'stop_reason': 'stop_sequence'}\n", + "Okay class, today we're going to learn about asynchronous programming. Asynchronous means things happening at different times, not necessarily in order. It's like when you're cooking dinner - you might put the pasta on to boil, then start chopping vegetables while the pasta cooks. You don't have to wait for the pasta to finish boiling before you can start on the vegetables. The two tasks are happening asynchronously. \n", + "\n", + "In programming, asynchronous functions allow the code to execute other operations while waiting for a long-running task to complete. Let's look at an example:\n", + "\n", + "```js\n", + "function cookPasta() {\n", + " console.log(\"Putting pasta on to boil...\");\n", + " // Simulate a long task\n", + " setTimeout(() => {\n", + " console.log(\"Pasta done!\");\n", + " }, 5000); \n", + "}\n", + "\n", + "function chopVegetables() {\n", + " console.log(\"Chopping vegetables...\");\n", + "}\n", + "\n", + "cookPasta();\n", + "chopVegetables();\n", + "```\n", + "\n", + "When we call `cookPasta()`, it starts the timer but doesn't wait 5 seconds - it immediately moves on to calling `chopVegetables()`. So the two functions run asynchronously. \n", + "\n", + "The key is that `cookPasta()` is non-blocking - it doesn't stop the rest of the code from running while it completes. 
This allows us to maximize efficiency and not waste time waiting.\n", + "\n", + "So in summary, asynchronous programming allows multiple operations to happen independently of each other, like cooking a meal. We avoid blocking code execution by using asynchronous functions. Any questions on this?\n" + ] + } + ], + "source": [ + "# example with a system message\n", + "response = bedrock.ChatCompletion.create(\n", + " model=\"anthropic.claude-v2\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"Explain asynchronous programming in the style of math teacher.\"},\n", + " ],\n", + ")\n", + "\n", + "print(response['choices'][0]['message']['content'])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'completion': \" Aye matey! Asynchronous programming be when ye fire yer cannons without waiting fer each shot to hit. Ye keep loadin' and shootin' while the cannonballs sail through the air. Ye don't know exactly when they'll strike the target, but ye keep sendin' 'em off. \\n\\nThe ship keeps movin' forward, not stalled waiting fer each blast. Other pirates keep swabbin' the decks and hoistin' the sails so we make progress while the cannons thunder. We tie callbacks to the cannons to handle the boom when they finally hit.\\n\\nArrr! Asynchronous programmin' means ye do lots o' tasks at once, not blocked by waitin' fer each one to finish. Ye move ahead and let functions handle the results when ready. It be faster than linear code that stops at each step. Thar be treasures ahead, lads! Keep those cannons roarin'!\", 'stop_reason': 'stop_sequence'}\n", + "Aye matey! Asynchronous programming be when ye fire yer cannons without waiting fer each shot to hit. Ye keep loadin' and shootin' while the cannonballs sail through the air. Ye don't know exactly when they'll strike the target, but ye keep sendin' 'em off. \n", + "\n", + "The ship keeps movin' forward, not stalled waiting fer each blast. Other pirates keep swabbin' the decks and hoistin' the sails so we make progress while the cannons thunder. We tie callbacks to the cannons to handle the boom when they finally hit.\n", + "\n", + "Arrr! Asynchronous programmin' means ye do lots o' tasks at once, not blocked by waitin' fer each one to finish. Ye move ahead and let functions handle the results when ready. It be faster than linear code that stops at each step. Thar be treasures ahead, lads! Keep those cannons roarin'!\n" + ] + } + ], + "source": [ + "# example without a system message and debug flag on:\n", + "response = bedrock.ChatCompletion.create(\n", + " model=\"anthropic.claude-v2\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"Explain asynchronous programming in the style of the pirate Blackbeard.\"},\n", + " ]\n", + ")\n", + "\n", + "print(response['choices'][0]['message']['content'])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
Few-shot prompting\n", + "\n", + "In some cases, it's easier to show the model what you want rather than tell the model what you want.\n", + "\n", + "One way to show the model what you want is with faked example messages.\n", + "\n", + "For example:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'completion': \" Changing direction at the last minute means we don't have time to do an exhaustive analysis for what we're providing to the client.\", 'stop_reason': 'stop_sequence'}\n", + "Changing direction at the last minute means we don't have time to do an exhaustive analysis for what we're providing to the client.\n" + ] + } + ], + "source": [ + "# An example of a faked few-shot conversation to prime the model into translating business jargon to simpler speech\n", + "response = bedrock.ChatCompletion.create(\n", + " model=\"anthropic.claude-v2\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful, pattern-following assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"Help me translate the following corporate jargon into plain English.\"},\n", + " {\"role\": \"assistant\", \"content\": \"Sure, I'd be happy to!\"},\n", + " {\"role\": \"user\", \"content\": \"New synergies will help drive top-line growth.\"},\n", + " {\"role\": \"assistant\", \"content\": \"Things working well together will increase revenue.\"},\n", + " {\"role\": \"user\", \"content\": \"Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage.\"},\n", + " {\"role\": \"assistant\", \"content\": \"Let's talk later when we're less busy about how to do better.\"},\n", + " {\"role\": \"user\", \"content\": \"This late pivot means we don't have time to boil the ocean for the client deliverable.\"},\n", + " ],\n", + ")\n", + "\n", + "print(response[\"choices\"][0][\"message\"][\"content\"])\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Not every attempt at engineering conversations will succeed at first.\n", + "\n", + "If your first attempts fail, don't be afraid to experiment with different ways of priming or conditioning the model.\n", + "\n", + "As an example, one developer discovered an increase in accuracy when they inserted a user message that said \"Great job so far, these have been perfect\" to help condition the model into providing higher quality responses.\n", + "\n", + "For more ideas on how to lift the reliability of the models, consider reading our guide on [techniques to increase reliability](../techniques_to_improve_reliability.md). It was written for non-chat models, but many of its principles still apply." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "openai", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/bedrock-stream-chat-completions.ipynb b/notebooks/bedrock-stream-chat-completions.ipynb new file mode 100644 index 0000000..d72a846 --- /dev/null +++ b/notebooks/bedrock-stream-chat-completions.ipynb @@ -0,0 +1,222 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# How to stream Chat Completion requests with Amazon Bedrock\n", + "\n", + "By default, when you request a completion, the entire completion is generated before being sent back in a single response.\n", + "\n", + "If you're generating long completions, waiting for the response can take many seconds.\n", + "\n", + "To get responses sooner, you can 'stream' the completion as it's being generated. This allows you to start printing or processing the beginning of the completion before the full completion is finished.\n", + "\n", + "To stream completions, set `stream=True` when calling the chat completions or completions endpoints. This will return an object that streams back the response as [data-only server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format). Extract chunks from the `delta` field rather than the `message` field.\n", + "\n", + "## Downsides\n", + "\n", + "Note that using `stream=True` in a production application makes it more difficult to moderate the content of the completions, as partial completions may be more difficult to evaluate. \n", + "\n", + "## Setup\n", + "\n", + "Before you can use `easyllm` with Amazon Bedrock you need setup permission and access to the models. You can do this by following of the instructions below:\n", + "* https://docs.aws.amazon.com/IAM/latest/UserGuide/getting-set-up.html\n", + "* https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_access-denied.html\n", + "* https://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html\n", + "\n", + "## Example code\n", + "\n", + "Below, this notebook shows:\n", + "1. What a typical chat completion response looks like\n", + "2. What a streaming chat completion response looks like\n", + "3. How much time is saved by streaming a chat completion" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# if needed, install and/or upgrade to the latest version of the EasyLLM Python library\n", + "%pip install --upgrade easyllm[bedrock] " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import easyllm # for API calls" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. What a typical chat completion response looks like\n", + "\n", + "With a typical ChatCompletions API call, the response is first computed and then returned all at once." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10/26/2023 17:34:57 - INFO - easyllm.utils.logging - boto3 Bedrock client successfully created!\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334497, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'role': 'assistant'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334498, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' Here'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334498, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' is counting to 100 with a comma'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334498, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' between each number and no newlines:\\n\\n1, 2, 3,'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334499, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' 4, 5, 6, 7, 8, 9, 10, 11'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334499, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ', 12, 13, 14, 15, 16, 17, 18,'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334499, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334500, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334500, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334501, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' 49, 50, 51'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334501, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ', 52, 53,'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334502, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' 54, 55, 56'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334503, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ', 57, 58, 59, 60, 61'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334504, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ', 62, 63, 64, 65, 66'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334504, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ', 67, 68, 69, 70, 71, 72, 73,'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334504, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ' 74, 75, 76, 77, 78, 79, 80, 81'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334505, 'model': 
'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ', 82, 83, 84, 85, 86, 87, 88, 89, 90, 91'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334505, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {'content': ', 92, 93, 94, 95, 96, 97, 98, 99, 100'}}]}\n", + "{'id': 'hf-Je8BGADPWN', 'object': 'chat.completion.chunk', 'created': 1698334505, 'model': 'anthropic.claude-v2', 'choices': [{'index': 0, 'delta': {}}]}\n" + ] + } + ], + "source": [ + "import os \n", + "# set env for prompt builder\n", + "os.environ[\"BEDROCK_PROMPT\"] = \"anthropic\" # vicuna, wizardlm, stablebeluga, open_assistant\n", + "os.environ[\"AWS_REGION\"] = \"us-east-1\" # change to your region\n", + "# os.environ[\"AWS_ACCESS_KEY_ID\"] = \"XXX\" # needed if not using boto3 session\n", + "# os.environ[\"AWS_SECRET_ACCESS_KEY\"] = \"XXX\" # needed if not using boto3 session\n", + "\n", + "from easyllm.clients import bedrock\n", + "\n", + "response = bedrock.ChatCompletion.create(\n", + " model='anthropic.claude-v2',\n", + " messages=[\n", + " {'role': 'user', 'content': 'Count to 100, with a comma between each number and no newlines. E.g., 1, 2, 3, ...'}\n", + " ],\n", + " stream=True\n", + ")\n", + "\n", + "for chunk in response:\n", + " print(chunk)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see above, streaming responses have a `delta` field rather than a `message` field. `delta` can hold things like:\n", + "- a role token (e.g., `{\"role\": \"assistant\"}`)\n", + "- a content token (e.g., `{\"content\": \"\\n\\n\"}`)\n", + "- nothing (e.g., `{}`), when the stream is over" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. How much time is saved by streaming a chat completion\n", + "\n", + "Now let's ask `meta-llama/Llama-2-70b-chat-hf` to count to 100 again, and see how long it takes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Here is counting to 100 with commas and no newlines:\n", + "\n", + "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100Full conversation received: Here is counting to 100 with commas and no newlines:\n", + "\n", + "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100\n" + ] + } + ], + "source": [ + "import os \n", + "# set env for prompt builder\n", + "os.environ[\"BEDROCK_PROMPT\"] = \"anthropic\" # vicuna, wizardlm, stablebeluga, open_assistant\n", + "os.environ[\"AWS_REGION\"] = \"us-east-1\" # change to your region\n", + "os.environ[\"AWS_PROFILE\"] = \"hf-sm\" # change to your region\n", + "# os.environ[\"AWS_ACCESS_KEY_ID\"] = \"XXX\" # needed if not using boto3 session\n", + "# os.environ[\"AWS_SECRET_ACCESS_KEY\"] = \"XXX\" # needed if not using boto3 session\n", + "from easyllm.clients import bedrock\n", + "\n", + "# send a ChatCompletion request to count to 100\n", + "response = bedrock.ChatCompletion.create(\n", + " model='anthropic.claude-v2',\n", + " messages=[\n", + " {'role': 'user', 'content': 'Count to 100, with a comma between each number and no newlines. 
E.g., 1, 2, 3, ...'}\n", + " ],\n", + " stream=True\n", + ")\n", + "\n", + "# create variables to collect the stream of chunks\n", + "collected_chunks = []\n", + "collected_messages = []\n", + "# iterate through the stream of events\n", + "for chunk in response:\n", + " collected_chunks.append(chunk) # save the event response\n", + " chunk_message = chunk['choices'][0]['delta'] # extract the message\n", + " print(chunk_message.get('content', ''), end='') # print the message\n", + " collected_messages.append(chunk_message) # save the message\n", + " \n", + "\n", + "# print the time delay and text received\n", + "full_reply_content = ''.join([m.get('content', '') for m in collected_messages])\n", + "print(f\"Full conversation received: {full_reply_content}\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.9 ('openai')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index e5da485..952236e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ dependencies = ["pydantic==2.1.1", "nanoid==2.0.0", "huggingface-hub==0.16.4"] [project.optional-dependencies] test = ["pytest", "ruff", "black", "isort", "mypy", "hatch"] +bedrock = ["boto3"] dev = ["ruff", "black", "isort", "mypy", "hatch"] docs = [ "mkdocs",