Transform Messages Capability (#1923)
* wip

* Adds docstrings

* fixed spellings

* wip

* fixed errors

* better class names

* adds tests

* added tests to workflow

* improved token counting

* improved notebook

* improved token counting in test

* improved docstrings

* fix inconsistencies

* changed by mistake

* fixed docstring

* fixed details

* improves tests + adds openai contrib test

* fix spelling oai contrib test

* clearer docstrings

* remove repeated docstr

* improved notebook

* adds metadata to notebook

* Improve outline and description (#2125)

* better dir structure

* clip max tokens to allowed tokens

* more accurate comments/docstrs

* add deprecation warning

* fix front matter desc

* add deprecation warning notebook

* undo local notebook settings changes

* format notebook

* format workflow

---------

Co-authored-by: gagb <[email protected]>
WaelKarkoub and gagb authored Mar 28, 2024
1 parent 6fbb9e8 commit 8cb5c14
Showing 7 changed files with 1,010 additions and 4 deletions.
37 changes: 37 additions & 0 deletions .github/workflows/contrib-tests.yml
@@ -291,3 +291,40 @@ jobs:
        with:
          file: ./coverage.xml
          flags: unittests

  TransformMessages:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, macos-latest, windows-2019]
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install packages and dependencies for all tests
        run: |
          python -m pip install --upgrade pip wheel
          pip install pytest
      - name: Install packages and dependencies for Transform Messages
        run: |
          pip install -e .
      - name: Set AUTOGEN_USE_DOCKER based on OS
        shell: bash
        run: |
          if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
            echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
          fi
      - name: Coverage
        run: |
          pip install coverage>=5.3
          coverage run -a -m pytest test/agentchat/contrib/capabilities/test_transform_messages.py --skip-openai
          coverage xml
      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        with:
          file: ./coverage.xml
          flags: unittests
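The job above exercises `test/agentchat/contrib/capabilities/test_transform_messages.py`, which is not part of this excerpt. As a rough sketch of the behavior that suite covers (hypothetical test name and body, not the repository's actual test code, based only on the `MessageHistoryLimiter` semantics shown in `transforms.py` below):

```python
# Hypothetical illustration of what the new test job exercises; the real
# test_transform_messages.py is not shown in this diff.
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter


def test_message_history_limiter_keeps_most_recent_messages():
    messages = [{"role": "user", "content": f"message {i}"} for i in range(5)]
    limiter = MessageHistoryLimiter(max_messages=2)

    # apply_transform keeps only the last `max_messages` entries, in order.
    assert limiter.apply_transform(messages) == messages[-2:]
```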
17 changes: 13 additions & 4 deletions autogen/agentchat/contrib/capabilities/context_handling.py
@@ -1,9 +1,18 @@
 import sys
-from termcolor import colored
-from typing import Dict, Optional, List
-from autogen import ConversableAgent
-from autogen import token_count_utils
+from typing import Dict, List, Optional
+from warnings import warn
+
 import tiktoken
+from termcolor import colored
+
+from autogen import ConversableAgent, token_count_utils
+
+warn(
+    "Context handling with TransformChatHistory is deprecated. "
+    "Please use TransformMessages from autogen/agentchat/contrib/capabilities/transform_messages.py instead.",
+    DeprecationWarning,
+    stacklevel=2,
+)


 class TransformChatHistory:
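The deprecation warning above directs existing users of `TransformChatHistory` toward the new capability. A minimal migration sketch, assuming the goal is simply to cap the history by message count and then by token count (the limits below are illustrative, not a one-to-one mapping of the old options):

```python
# Illustrative replacement for the deprecated TransformChatHistory capability.
# The specific limits are examples only.
from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter, MessageTokenLimiter

agent = ConversableAgent("assistant", llm_config=False)

context_handling = TransformMessages(
    transforms=[
        MessageHistoryLimiter(max_messages=10),  # keep only the 10 most recent messages
        MessageTokenLimiter(max_tokens=1000),  # then cap the remaining history at 1000 tokens
    ]
)
context_handling.add_to_agent(agent)
```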
90 changes: 90 additions & 0 deletions autogen/agentchat/contrib/capabilities/transform_messages.py
@@ -0,0 +1,90 @@
import copy
from typing import Dict, List

from termcolor import colored

from autogen import ConversableAgent
from .transforms import MessageTransform


class TransformMessages:
"""Agent capability for transforming messages before reply generation.
This capability allows you to apply a series of message transformations to
a ConversableAgent's incoming messages before they are processed for response
generation. This is useful for tasks such as:
- Limiting the number of messages considered for context.
- Truncating messages to meet token limits.
- Filtering sensitive information.
- Customizing message formatting.
To use `TransformMessages`:
1. Create message transformations (e.g., `MessageHistoryLimiter`, `MessageTokenLimiter`).
2. Instantiate `TransformMessages` with a list of these transformations.
3. Add the `TransformMessages` instance to your `ConversableAgent` using `add_to_agent`.
NOTE: Order of message transformations is important. You could get different results based on
the order of transformations.
Example:
```python
from agentchat import ConversableAgent
from agentchat.contrib.capabilities import TransformMessages, MessageHistoryLimiter, MessageTokenLimiter
max_messages = MessageHistoryLimiter(max_messages=2)
truncate_messages = MessageTokenLimiter(max_tokens=500)
transform_messages = TransformMessages(transforms=[max_messages, truncate_messages])
agent = ConversableAgent(...)
transform_messages.add_to_agent(agent)
```
"""

    def __init__(self, *, transforms: List[MessageTransform] = []):
        """
        Args:
            transforms: A list of message transformations to apply.
        """
        self._transforms = transforms

    def add_to_agent(self, agent: ConversableAgent):
        """Adds the message transformations capability to the specified ConversableAgent.

        This function performs the following modifications to the agent:

        1. Registers a hook that automatically transforms all messages before they are processed for
           response generation.
        """
        agent.register_hook(hookable_method="process_all_messages_before_reply", hook=self._transform_messages)

    def _transform_messages(self, messages: List[Dict]) -> List[Dict]:
        temp_messages = copy.deepcopy(messages)
        system_message = None

        # Set aside the system message (if any) so no transform can drop or truncate it.
        if messages[0]["role"] == "system":
            system_message = copy.deepcopy(messages[0])
            temp_messages.pop(0)

        # Apply the transforms in the order they were provided.
        for transform in self._transforms:
            temp_messages = transform.apply_transform(temp_messages)

        if system_message:
            temp_messages.insert(0, system_message)

        self._print_stats(messages, temp_messages)

        return temp_messages

    def _print_stats(self, pre_transform_messages: List[Dict], post_transform_messages: List[Dict]):
        pre_transform_messages_len = len(pre_transform_messages)
        post_transform_messages_len = len(post_transform_messages)

        # Only report when the transforms actually reduced the number of messages.
        if pre_transform_messages_len > post_transform_messages_len:
            print(
                colored(
                    f"Number of messages reduced from {pre_transform_messages_len} to {post_transform_messages_len}.",
                    "yellow",
                )
            )
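Because `transforms` accepts any object that satisfies the `MessageTransform` protocol (defined in `transforms.py` below), custom transforms can be combined with the built-in ones. A hypothetical sketch (not part of this commit) covering the docstring's "filtering sensitive information" use case:

```python
# Hypothetical custom transform; RedactSecrets is not part of this commit.
import copy
from typing import Dict, List

from autogen.agentchat.contrib.capabilities.transform_messages import TransformMessages
from autogen.agentchat.contrib.capabilities.transforms import MessageHistoryLimiter


class RedactSecrets:
    """Replaces a known secret string in text content with a placeholder."""

    def __init__(self, secret: str):
        self._secret = secret

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        redacted = copy.deepcopy(messages)
        for message in redacted:
            if isinstance(message.get("content"), str):
                message["content"] = message["content"].replace(self._secret, "[REDACTED]")
        return redacted


transform_messages = TransformMessages(
    transforms=[RedactSecrets("sk-12345"), MessageHistoryLimiter(max_messages=5)],
)
# transform_messages.add_to_agent(agent)  # attach to a ConversableAgent as shown in the class docstring
```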
210 changes: 210 additions & 0 deletions autogen/agentchat/contrib/capabilities/transforms.py
@@ -0,0 +1,210 @@
import sys
from typing import Any, Dict, List, Optional, Protocol, Union

import tiktoken
from termcolor import colored

from autogen import token_count_utils


class MessageTransform(Protocol):
"""Defines a contract for message transformation.
Classes implementing this protocol should provide an `apply_transform` method
that takes a list of messages and returns the transformed list.
"""

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Applies a transformation to a list of messages.
Args:
messages: A list of dictionaries representing messages.
Returns:
A new list of dictionaries containing the transformed messages.
"""
...


class MessageHistoryLimiter:
"""Limits the number of messages considered by an agent for response generation.
This transform keeps only the most recent messages up to the specified maximum number of messages (max_messages).
It trims the conversation history by removing older messages, retaining only the most recent messages.
"""

def __init__(self, max_messages: Optional[int] = None):
"""
Args:
max_messages (None or int): Maximum number of messages to keep in the context.
Must be greater than 0 if not None.
"""
self._validate_max_messages(max_messages)
self._max_messages = max_messages

def apply_transform(self, messages: List[Dict]) -> List[Dict]:
"""Truncates the conversation history to the specified maximum number of messages.
This method returns a new list containing the most recent messages up to the specified
maximum number of messages (max_messages). If max_messages is None, it returns the
original list of messages unmodified.
Args:
messages (List[Dict]): The list of messages representing the conversation history.
Returns:
List[Dict]: A new list containing the most recent messages up to the specified maximum.
"""
if self._max_messages is None:
return messages

return messages[-self._max_messages :]

    def _validate_max_messages(self, max_messages: Optional[int]):
        if max_messages is not None and max_messages < 1:
            raise ValueError("max_messages must be None or greater than 0")


class MessageTokenLimiter:
"""Truncates messages to meet token limits for efficient processing and response generation.
This transformation applies two levels of truncation to the conversation history:
1. Truncates each individual message to the maximum number of tokens specified by max_tokens_per_message.
2. Truncates the overall conversation history to the maximum number of tokens specified by max_tokens.
NOTE: Tokens are counted using the encoder for the specified model. Different models may yield different token
counts for the same text.
NOTE: For multimodal LLMs, the token count may be inaccurate as it does not account for the non-text input
(e.g images).
The truncation process follows these steps in order:
1. Messages are processed in reverse order (newest to oldest).
2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text
and other types of content, only the text content is truncated.
3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count
exceeds this limit, the current message being processed as well as any remaining messages are discarded.
4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the
original message order.
"""

    def __init__(
        self,
        max_tokens_per_message: Optional[int] = None,
        max_tokens: Optional[int] = None,
        model: str = "gpt-3.5-turbo-0613",
    ):
        """
        Args:
            max_tokens_per_message (None or int): Maximum number of tokens to keep in each message.
                Must be greater than or equal to 0 if not None.
            max_tokens (Optional[int]): Maximum number of tokens to keep in the chat history.
                Must be greater than or equal to 0 if not None.
            model (str): The target OpenAI model for tokenization alignment.
        """
        self._model = model
        self._max_tokens_per_message = self._validate_max_tokens(max_tokens_per_message)
        self._max_tokens = self._validate_max_tokens(max_tokens)

    def apply_transform(self, messages: List[Dict]) -> List[Dict]:
        """Applies token truncation to the conversation history.

        Args:
            messages (List[Dict]): The list of messages representing the conversation history.

        Returns:
            List[Dict]: A new list containing the truncated messages up to the specified token limits.
        """
        assert self._max_tokens_per_message is not None
        assert self._max_tokens is not None

        temp_messages = messages.copy()
        processed_messages = []
        processed_messages_tokens = 0

        # calculate tokens for all messages
        total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages)

        for msg in reversed(temp_messages):
            msg["content"] = self._truncate_str_to_tokens(msg["content"])
            msg_tokens = _count_tokens(msg["content"])

            # If adding this message would exceed the token limit, discard it and all remaining messages
            if processed_messages_tokens + msg_tokens > self._max_tokens:
                break

            # prepend the message to the list to preserve order
            processed_messages_tokens += msg_tokens
            processed_messages.insert(0, msg)

        if total_tokens > processed_messages_tokens:
            print(
                colored(
                    f"Truncated {total_tokens - processed_messages_tokens} tokens. Tokens reduced from {total_tokens} to {processed_messages_tokens}",
                    "yellow",
                )
            )

        return processed_messages

    def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]:
        if isinstance(contents, str):
            return self._truncate_tokens(contents)
        elif isinstance(contents, list):
            return self._truncate_multimodal_text(contents)
        else:
            raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}")

    def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Truncates text content within a list of multimodal elements, preserving the overall structure."""
        tmp_contents = []
        for content in contents:
            if content["type"] == "text":
                truncated_text = self._truncate_tokens(content["text"])
                tmp_contents.append({"type": "text", "text": truncated_text})
            else:
                tmp_contents.append(content)
        return tmp_contents

    def _truncate_tokens(self, text: str) -> str:
        encoding = tiktoken.encoding_for_model(self._model)  # Get the appropriate tokenizer

        encoded_tokens = encoding.encode(text)
        truncated_tokens = encoded_tokens[: self._max_tokens_per_message]
        truncated_text = encoding.decode(truncated_tokens)  # Decode back to text

        return truncated_text

    def _validate_max_tokens(self, max_tokens: Optional[int] = None) -> Optional[int]:
        if max_tokens is not None and max_tokens < 0:
            raise ValueError("max_tokens and max_tokens_per_message must be None or greater than or equal to 0")

        try:
            allowed_tokens = token_count_utils.get_max_token_limit(self._model)
        except Exception:
            print(colored(f"Model {self._model} not found in token_count_utils.", "yellow"))
            allowed_tokens = None

        if max_tokens is not None and allowed_tokens is not None:
            if max_tokens > allowed_tokens:
                print(
                    colored(
                        f"Max token was set to {max_tokens}, but {self._model} can only accept {allowed_tokens} tokens. Capping it to {allowed_tokens}.",
                        "yellow",
                    )
                )
                return allowed_tokens

        return max_tokens if max_tokens is not None else sys.maxsize


def _count_tokens(content: Union[str, List[Dict[str, Any]]]) -> int:
    token_count = 0
    if isinstance(content, str):
        token_count = token_count_utils.count_token(content)
    elif isinstance(content, list):
        for item in content:
            token_count += _count_tokens(item.get("text", ""))
    return token_count
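For reference, a standalone sketch of the two-level truncation performed by `MessageTokenLimiter` (the limits are illustrative; exact results depend on the model's tokenizer, so no specific counts are claimed):

```python
# Illustrative standalone use of MessageTokenLimiter; the limits are arbitrary examples.
from autogen.agentchat.contrib.capabilities.transforms import MessageTokenLimiter

limiter = MessageTokenLimiter(max_tokens_per_message=20, max_tokens=50)

messages = [
    {"role": "user", "content": "please summarize this very long document " * 20},
    {"role": "assistant", "content": "here is a very long summary " * 20},
    {"role": "user", "content": "short follow-up question"},
]

# Each message's text is clipped to at most 20 tokens, then messages are dropped
# (oldest first) once the running total would exceed 50 tokens.
truncated = limiter.apply_transform(messages)
print(len(truncated), [msg["content"] for msg in truncated])
```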
4 changes: 4 additions & 0 deletions notebook/agentchat_capability_long_context_handling.ipynb
@@ -6,6 +6,10 @@
"source": [
"# Handling A Long Context via `TransformChatHistory`\n",
"\n",
"<div class=\"alert alert-warning\" role=\"alert\">\n",
" <strong>Deprecation Notice:</strong> <code>TransformChatHistory</code> is no longer supported. Please use <code>TransformMessages</code> as the new standard method. For the latest examples, visit the notebook at <a href=\"https://github.com/microsoft/autogen/blob/main/notebook/agentchat_transform_messages.ipynb\" target=\"_blank\">notebook/agentchat_transform_messages.ipynb</a>.\n",
"</div>\n",
"\n",
"This notebook illustrates how you can use the `TransformChatHistory` capability to give any `Conversable` agent an ability to handle a long context. \n",
"\n",
"````{=mdx}\n",
